Slow right shift (table lookups)

Hello,

I read that “TFHE also has a fast primitive for right bit-shift”. Is that supported in Concrete, or are bit shifts always implemented as a table lookup? Since fast removal of LSBs is supported, support for fast right shifts would feel natural.
When I test left and right shifts, they are translated to table lookups:

My circuit:

@fhe.compiler({"query": "encrypted"})
def lookup(query):
    index = query >> 1
    return index

My Computation Graph:

----------------------------------------------------------------------------------------------------------------------------
%0 = query                      # EncryptedTensor<uint2, shape=(378,)>        ∈ [0, 2]
%1 = 1                          # ClearScalar<uint1>                          ∈ [1, 1]
%2 = right_shift(%0, %1)        # EncryptedTensor<uint1, shape=(378,)>        ∈ [0, 1]
return %2
----------------------------------------------------------------------------------------------------------------------------

My MLIR:

module {
  func.func @main(%arg0: tensor<378x!FHE.eint<2>>) -> tensor<378x!FHE.eint<1>> {
    %c1_i2 = arith.constant 1 : i2
    %cst = arith.constant dense<[0, 0, 1, 1]> : tensor<4xi64>
    %0 = "FHELinalg.apply_lookup_table"(%arg0, %cst) : (tensor<378x!FHE.eint<2>>, tensor<4xi64>) -> tensor<378x!FHE.eint<1>>
    return %0 : tensor<378x!FHE.eint<1>>
  }
}

Hi @lulu,

At the moment, it works but it’s not optimal. In the near future, it might be :slightly_smiling_face:

Hope this answers your question!

Ok, thank you for your answer.

Is there any update about this “fast primitive for right bit-shift”? I cannot find anything in the documentation on how to enable it, and right shift seems to always be translated to a LUT.

Hi,

Shift operations are not optimized yet, but you can rely on the bit extraction primitive, which can be faster.
E.g.

# extract bits 1 and 2 (slice [1:3]), ignoring all other bits
index = fhe.bits(query)[1:3]
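
For reference, here is a minimal end-to-end sketch of the original circuit rewritten to use fhe.bits instead of >>; the 5-bit inputset and the test values are only placeholders, not taken from your setup:

from concrete import fhe

@fhe.compiler({"query": "encrypted"})
def lookup(query):
    # keep bits 1 and 2 of query instead of lowering (query >> 1) to a TLU
    return fhe.bits(query)[1:3]

# placeholder inputset covering 5-bit values (uint5); adjust to your data range
inputset = range(32)
circuit = lookup.compile(inputset)

for sample in (0, 5, 21, 31):
    assert circuit.encrypt_run_decrypt(sample) == (sample >> 1) % 4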

fhe.bits(...)[...] is converted to either a plain TLU or a sequence of bit extractions, depending on which one is faster and on other constraints.
E.g.

%0 = query                # EncryptedScalar<uint5>        ∈ [0, 31]
%1 = bits(%0)[1:3]        # EncryptedScalar<uint2>        ∈ [0, 3]
return %1

is converted to:

module {
  func.func @lookup(%arg0: !FHE.eint<5>) -> !FHE.eint<2> {
    %0 = "FHE.lsb"(%arg0) : (!FHE.eint<5>) -> !FHE.eint<5>
    %1 = "FHE.sub_eint"(%arg0, %0) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
    %2 = "FHE.reinterpret_precision"(%1) : (!FHE.eint<5>) -> !FHE.eint<4>
    %3 = "FHE.lsb"(%2) : (!FHE.eint<4>) -> !FHE.eint<4>
    %4 = "FHE.sub_eint"(%2, %3) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
    %5 = "FHE.reinterpret_precision"(%4) : (!FHE.eint<4>) -> !FHE.eint<3>
    %c4_i5 = arith.constant 4 : i5
    %6 = "FHE.mul_eint_int"(%3, %c4_i5) : (!FHE.eint<4>, i5) -> !FHE.eint<4>
    %7 = "FHE.reinterpret_precision"(%6) : (!FHE.eint<4>) -> !FHE.eint<2>
    %8 = "FHE.lsb"(%5) : (!FHE.eint<3>) -> !FHE.eint<3>
    %9 = "FHE.sub_eint"(%5, %8) : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3>
    %10 = "FHE.reinterpret_precision"(%9) : (!FHE.eint<3>) -> !FHE.eint<2>
    %c4_i4 = arith.constant 4 : i4
    %11 = "FHE.mul_eint_int"(%8, %c4_i4) : (!FHE.eint<3>, i4) -> !FHE.eint<3>
    %12 = "FHE.reinterpret_precision"(%11) : (!FHE.eint<3>) -> !FHE.eint<2>
    %13 = "FHE.add_eint"(%7, %12) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
    return %13 : !FHE.eint<2>
  }
}

But

%0 = query                # EncryptedScalar<uint4>        ∈ [0, 15]
%1 = bits(%0)[1:3]        # EncryptedScalar<uint2>        ∈ [0, 3]
return %1

is converted to:

module {
  func.func @lookup(%arg0: !FHE.eint<4>) -> !FHE.eint<2> {
    %cst = arith.constant dense<[0, 0, 1, 1, 2, 2, 3, 3, 0, 0, 1, 1, 2, 2, 3, 3]> : tensor<16xi64>
    %0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<4>, tensor<16xi64>) -> !FHE.eint<2>
    return %0 : !FHE.eint<2>
  }
}

since a TLU with a 4-bit input is faster than extracting 3 bits (three FHE.lsb operations in the MLIR).
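
If you want to check which lowering was chosen for your bit widths, you can compile the same function over two different input ranges and compare the generated MLIR. A small sketch, assuming the compiled circuit exposes its MLIR via the mlir attribute (as in current Concrete Python):

from concrete import fhe

def make_circuit(max_value):
    # same fhe.bits slice as above; only the input range (and thus the
    # inferred bit width) changes between the two compilations
    @fhe.compiler({"query": "encrypted"})
    def lookup(query):
        return fhe.bits(query)[1:3]

    return lookup.compile(range(max_value + 1))

print(make_circuit(31).mlir)  # uint5 input: expect FHE.lsb / sub_eint chains
print(make_circuit(15).mlir)  # uint4 input: expect a single apply_lookup_table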