This feature isn't used in Concrete ML (CML) yet, but it can be enabled when invoking the compiler directly. It could potentially be used in Concrete ML later.
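This feature breaks an $N$-bit integer into $N / N_{\text{chunk}}$ chunks of $N_{\text{chunk}}$ bits each, where $N_{\text{chunk}}$ is the chunk width. It then rewrites the computation to operate on those chunks instead. This is why the resulting MLIR below contains many more operations. For example, in the output below each 8-bit `!FHE.eint<8>` argument becomes a tensor of 4 chunks, `tensor<4x!FHE.eint<4>>`.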
"What's the purpose?", someone might ask. Essentially, it lets us use smaller parameters, and therefore get faster execution. Even though many more operations are performed, their total cost can be lower than that of the original computation for some use cases (depending on the computation and the parameters).
For example, on my local setup I can run:
$ concretecompiler --action=dump-fhe --chunk-integers laser.mlir
where `laser.mlir` contains:
// Example input: adds two !FHE.eint<8> values with a single FHE.add_eint
// and returns the 8-bit result. Passing --chunk-integers rewrites this
// function into the chunked form shown in the output below.
func.func @add_eint(%arg0: !FHE.eint<8>, %arg1: !FHE.eint<8>) -> !FHE.eint<8> {
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<8>, !FHE.eint<8>) -> (!FHE.eint<8>)
return %1: !FHE.eint<8>
}
and the output is:
// Output of `concretecompiler --action=dump-fhe --chunk-integers` for the
// 8-bit add_eint example. The MANP attributes are annotations attached by the
// compiler; they are not present in the input.
module {
// Each !FHE.eint<8> argument and the result are now tensors of 4 chunks,
// each chunk being an !FHE.eint<4>.
func.func @add_eint(%arg0: tensor<4x!FHE.eint<4>>, %arg1: tensor<4x!FHE.eint<4>>) -> tensor<4x!FHE.eint<4>> {
// Index constants used to address chunks 0..3.
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
// 16-entry lookup table over a 4-bit chunk value: entries 0-3 map to 0 and
// entries 4-15 map to 1, i.e. it extracts carry = (sum >= 4).
%cst = arith.constant dense<[0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]> : tensor<16xi64>
// Carry weight: multiplied by the carry and subtracted from the chunk sum.
%c4_i5 = arith.constant 4 : i5
%c0 = arith.constant 0 : index
// Encrypted zero; it is the value added as the incoming carry of every chunk below.
%0 = "FHE.zero"() {MANP = 1 : ui1} : () -> !FHE.eint<4>
// Zero-initialized result tensor; each chunk result is inserted into it in turn.
%1 = "FHE.zero_tensor"() {MANP = 1 : ui1} : () -> tensor<4x!FHE.eint<4>>
// Chunk 0: sum = arg0[0] + arg1[0] + %0; carry = LUT(sum);
// result = sum - 4 * carry, inserted at index 0.
%extracted = tensor.extract %arg0[%c0] {MANP = 1 : ui1} : tensor<4x!FHE.eint<4>>
%extracted_0 = tensor.extract %arg1[%c0] {MANP = 1 : ui1} : tensor<4x!FHE.eint<4>>
%2 = "FHE.add_eint"(%extracted, %extracted_0) {MANP = 2 : ui3} : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
%3 = "FHE.add_eint"(%2, %0) {MANP = 2 : ui3} : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
%4 = "FHE.apply_lookup_table"(%3, %cst) {MANP = 1 : ui1} : (!FHE.eint<4>, tensor<16xi64>) -> !FHE.eint<4>
%5 = "FHE.mul_eint_int"(%4, %c4_i5) {MANP = 4 : ui11} : (!FHE.eint<4>, i5) -> !FHE.eint<4>
%6 = "FHE.sub_eint"(%3, %5) {MANP = 5 : ui13} : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
%inserted = tensor.insert %6 into %1[%c0] {MANP = 5 : ui13} : tensor<4x!FHE.eint<4>>
// Chunk 1: same pattern at index 1.
// NOTE(review): the incoming carry added here is again %0 (FHE.zero), not the
// carry %4 computed for chunk 0. Likewise chunks 2 and 3 add %0 rather than
// %9 / %14. Each LUT carry is only subtracted from its own chunk, so as
// shown no carry is propagated between chunks. Confirm whether this is intended.
%extracted_1 = tensor.extract %arg0[%c1] {MANP = 1 : ui1} : tensor<4x!FHE.eint<4>>
%extracted_2 = tensor.extract %arg1[%c1] {MANP = 1 : ui1} : tensor<4x!FHE.eint<4>>
%7 = "FHE.add_eint"(%extracted_1, %extracted_2) {MANP = 2 : ui3} : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
%8 = "FHE.add_eint"(%7, %0) {MANP = 2 : ui3} : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
%9 = "FHE.apply_lookup_table"(%8, %cst) {MANP = 1 : ui1} : (!FHE.eint<4>, tensor<16xi64>) -> !FHE.eint<4>
%10 = "FHE.mul_eint_int"(%9, %c4_i5) {MANP = 4 : ui11} : (!FHE.eint<4>, i5) -> !FHE.eint<4>
%11 = "FHE.sub_eint"(%8, %10) {MANP = 5 : ui13} : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
%inserted_3 = tensor.insert %11 into %inserted[%c1] {MANP = 5 : ui13} : tensor<4x!FHE.eint<4>>
// Chunk 2: same pattern at index 2.
%extracted_4 = tensor.extract %arg0[%c2] {MANP = 1 : ui1} : tensor<4x!FHE.eint<4>>
%extracted_5 = tensor.extract %arg1[%c2] {MANP = 1 : ui1} : tensor<4x!FHE.eint<4>>
%12 = "FHE.add_eint"(%extracted_4, %extracted_5) {MANP = 2 : ui3} : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
%13 = "FHE.add_eint"(%12, %0) {MANP = 2 : ui3} : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
%14 = "FHE.apply_lookup_table"(%13, %cst) {MANP = 1 : ui1} : (!FHE.eint<4>, tensor<16xi64>) -> !FHE.eint<4>
%15 = "FHE.mul_eint_int"(%14, %c4_i5) {MANP = 4 : ui11} : (!FHE.eint<4>, i5) -> !FHE.eint<4>
%16 = "FHE.sub_eint"(%13, %15) {MANP = 5 : ui13} : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
%inserted_6 = tensor.insert %16 into %inserted_3[%c2] {MANP = 5 : ui13} : tensor<4x!FHE.eint<4>>
// Chunk 3: same pattern at index 3. Its carry %19 is only subtracted from this chunk.
%extracted_7 = tensor.extract %arg0[%c3] {MANP = 1 : ui1} : tensor<4x!FHE.eint<4>>
%extracted_8 = tensor.extract %arg1[%c3] {MANP = 1 : ui1} : tensor<4x!FHE.eint<4>>
%17 = "FHE.add_eint"(%extracted_7, %extracted_8) {MANP = 2 : ui3} : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
%18 = "FHE.add_eint"(%17, %0) {MANP = 2 : ui3} : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
%19 = "FHE.apply_lookup_table"(%18, %cst) {MANP = 1 : ui1} : (!FHE.eint<4>, tensor<16xi64>) -> !FHE.eint<4>
%20 = "FHE.mul_eint_int"(%19, %c4_i5) {MANP = 4 : ui11} : (!FHE.eint<4>, i5) -> !FHE.eint<4>
%21 = "FHE.sub_eint"(%18, %20) {MANP = 5 : ui13} : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
%inserted_9 = tensor.insert %21 into %inserted_6[%c3] {MANP = 5 : ui13} : tensor<4x!FHE.eint<4>>
// The tensor holding all 4 chunk results is the function result.
return %inserted_9 : tensor<4x!FHE.eint<4>>
}
}