Keygen problem when iterating in a loop

Hey all,

Does anyone maybe have an idea about the following problem?

I have a class with the following functions:

class concrete_cnn():

    def __init__(self, n_bits,...):
        self.n_bits = n_bits
        self.concrete_module = None
        ....

    def train_brevitas_model(self,...):
        #  1. QAT training
        ....
        # 2. compile model
        self.concrete_module = compile_brevitas_qat_model(self.brevitas_model, x_train[:100])

        bit_width = self.concrete_module.fhe_circuit.graph.maximum_integer_bit_width()
        if bit_width > 16:
            print(f'WARNING: Maximum bit-width of 16 exceeded ({bit_width})')

        return self

    def evaluate(self,...):
        # 1. Key generation
        self.concrete_module.fhe_circuit.keygen(force=True)
        ...
        # 2. predict on test set
        ...
        return y_pred

This class works fine for my given dataset and I don't have any problems when running a single n_bits configuration (it takes about 10 minutes for n_bits=6 without exceeding the maximum accumulator bit-width).

When I now want to log the runtime and performance for each possible n_bits configuration with a loop (see the code and timing sketch below), my notebook gets stuck in the key generation for hours and finally crashes. This always happens at around n_bits >= 4; everything below that was no problem. There is also no indication of a RAM issue.

for bits in range(2,7):
    cnn = concrete_cnn(n_bits=bits)
    # train
    cnn = cnn.train_brevitas_model(x_train,...)
    # evaluate
    y_pred = cnn.evaluate(x_test,...)
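
For reference, the logging around the loop looks roughly like this (a simplified sketch; y_test and accuracy_score are just stand-ins for whatever metric I actually record):

import time
from sklearn.metrics import accuracy_score

results = {}
for bits in range(2, 7):
    cnn = concrete_cnn(n_bits=bits)
    # train (includes compilation)
    t0 = time.perf_counter()
    cnn = cnn.train_brevitas_model(x_train, ...)
    train_time = time.perf_counter() - t0
    # evaluate (includes key generation, which is where it gets stuck)
    t0 = time.perf_counter()
    y_pred = cnn.evaluate(x_test, ...)
    eval_time = time.perf_counter() - t0
    results[bits] = (train_time, eval_time, accuracy_score(y_test, y_pred))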

Here is a screenshot of my Jupyter logs from VS Code:

Hello @lstk,
Would you also be able to provide the logs when compiling the problematic model? You can get those by setting verbose=True in compile_brevitas_qat_model and then copy/paste them here! This would help us a lot in determining the source of your issue :wink:
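
For example, a minimal sketch reusing the compilation call from your class (other arguments unchanged):

self.concrete_module = compile_brevitas_qat_model(
    self.brevitas_model,
    x_train[:100],
    verbose=True,  # prints the computation graph, optimizer output and statistics
)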

Thanks

This is from the iteration (n_bits=4) where the Jupyter kernel crashed:

Computation Graph
--------------------------------------------------------------------------------
 %0 = _inp_1                                                                                       # EncryptedTensor<uint3, shape=(1, 1, 28, 28)>        ∈ [0, 7]
 %1 = [[[[ 0  0  ...   2 -2]]]]                                                                    # ClearTensor<int4, shape=(3, 1, 5, 5)>               ∈ [-7, 7]            @ /1/Conv.conv
 %2 = conv2d(%0, %1, [0 0 0], pads=[0, 0, 0, 0], strides=(1, 1), dilations=(1, 1), group=1)        # EncryptedTensor<int10, shape=(1, 3, 24, 24)>        ∈ [-231, 300]        @ /1/Conv.conv
 %3 = subgraph(%2)                                                                                 # EncryptedTensor<uint4, shape=(1, 3, 24, 24)>        ∈ [0, 15]
 %4 = [[[[1 1]   ...   [1 1]]]]                                                                    # ClearTensor<uint1, shape=(3, 3, 2, 2)>              ∈ [0, 1]             @ /3/AveragePool.avgpool
 %5 = conv2d(%3, %4, [0 0 0], pads=[0, 0, 0, 0], strides=(2, 2), dilations=(1, 1), group=1)        # EncryptedTensor<uint6, shape=(1, 3, 12, 12)>        ∈ [0, 60]            @ /3/AveragePool.avgpool
 %6 = subgraph(%5)                                                                                 # EncryptedTensor<uint3, shape=(1, 3, 12, 12)>        ∈ [0, 7]
 %7 = reshape(%6, newshape=(1, 432))                                                               # EncryptedTensor<uint3, shape=(1, 432)>              ∈ [0, 7]
 %8 = subgraph(%7)                                                                                 # EncryptedTensor<uint3, shape=(1, 432)>              ∈ [0, 7]
 %9 = [[ 0  0] [ ... ] [-3  3]]                                                                    # ClearTensor<int4, shape=(432, 2)>                   ∈ [-6, 7]            @ /7/Gemm.matmul
%10 = matmul(%8, %9)                                                                               # EncryptedTensor<int11, shape=(1, 2)>                ∈ [-583, 594]        @ /7/Gemm.matmul
return %10

Subgraphs:

    %3 = subgraph(%2):

         %0 = input                            # EncryptedTensor<uint3, shape=(1, 3, 24, 24)>          @ /1/Conv.conv
         %1 = astype(%0, dtype=float64)        # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
         %2 = 0                                # ClearScalar<uint1>
         %3 = add(%1, %2)                      # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
         %4 = [[[[0]]  [[0]]  [[0]]]]          # ClearTensor<uint1, shape=(1, 3, 1, 1)>
         %5 = subtract(%3, %4)                 # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
         %6 = 0.00508421244397636              # ClearScalar<float64>
         %7 = multiply(%6, %5)                 # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
         %8 = [[[[ 6.755 ... 34e-01]]]]        # ClearTensor<float32, shape=(1, 3, 1, 1)>
         %9 = add(%7, %8)                      # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
        %10 = 0                                # ClearScalar<uint1>
        %11 = maximum(%9, %10)                 # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
        %12 = 0.03629405                       # ClearScalar<float32>
        %13 = divide(%11, %12)                 # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
        %14 = 0.0                              # ClearScalar<float32>
        %15 = add(%13, %14)                    # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
        %16 = 0.0                              # ClearScalar<float64>
        %17 = 15.0                             # ClearScalar<float64>
        %18 = clip(%15, %16, %17)              # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
        %19 = rint(%18)                        # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
        %20 = 0.0                              # ClearScalar<float32>
        %21 = subtract(%19, %20)               # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
        %22 = 0.03629405                       # ClearScalar<float32>
        %23 = multiply(%21, %22)               # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
        %24 = 0.03629405051469803              # ClearScalar<float64>
        %25 = divide(%23, %24)                 # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
        %26 = 0                                # ClearScalar<uint1>
        %27 = add(%25, %26)                    # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
        %28 = rint(%27)                        # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
        %29 = astype(%28, dtype=int_)          # EncryptedTensor<uint1, shape=(1, 3, 24, 24)>
        return %29

    %6 = subgraph(%5):

         %0 = input                            # EncryptedTensor<uint1, shape=(1, 3, 12, 12)>          @ /3/AveragePool.avgpool
         %1 = astype(%0, dtype=float64)        # EncryptedTensor<float64, shape=(1, 3, 12, 12)>
         %2 = 0.25                             # ClearScalar<float64>
         %3 = multiply(%1, %2)                 # EncryptedTensor<float64, shape=(1, 3, 12, 12)>
         %4 = 0                                # ClearScalar<uint1>
         %5 = subtract(%3, %4)                 # EncryptedTensor<float64, shape=(1, 3, 12, 12)>
         %6 = 0.03629405051469803              # ClearScalar<float64>
         %7 = multiply(%5, %6)                 # EncryptedTensor<float64, shape=(1, 3, 12, 12)>
         %8 = 0.07292149                       # ClearScalar<float32>
         %9 = divide(%7, %8)                   # EncryptedTensor<float64, shape=(1, 3, 12, 12)>
        %10 = 0.0                              # ClearScalar<float32>
        %11 = add(%9, %10)                     # EncryptedTensor<float64, shape=(1, 3, 12, 12)>
        %12 = -8.0                             # ClearScalar<float64>
        %13 = 7.0                              # ClearScalar<float64>
        %14 = clip(%11, %12, %13)              # EncryptedTensor<float64, shape=(1, 3, 12, 12)>
        %15 = rint(%14)                        # EncryptedTensor<float64, shape=(1, 3, 12, 12)>
        %16 = 0.0                              # ClearScalar<float32>
        %17 = subtract(%15, %16)               # EncryptedTensor<float64, shape=(1, 3, 12, 12)>
        %18 = 0.07292149                       # ClearScalar<float32>
        %19 = multiply(%17, %18)               # EncryptedTensor<float64, shape=(1, 3, 12, 12)>
        %20 = 0.07292149215936661              # ClearScalar<float64>
        %21 = divide(%19, %20)                 # EncryptedTensor<float64, shape=(1, 3, 12, 12)>
        %22 = 0                                # ClearScalar<uint1>
        %23 = add(%21, %22)                    # EncryptedTensor<float64, shape=(1, 3, 12, 12)>
        %24 = rint(%23)                        # EncryptedTensor<float64, shape=(1, 3, 12, 12)>
        %25 = astype(%24, dtype=int_)          # EncryptedTensor<uint1, shape=(1, 3, 12, 12)>
        return %25

    %8 = subgraph(%7):

         %0 = input                          # EncryptedTensor<uint1, shape=(1, 432)>
         %1 = 0.0                            # ClearScalar<float64>
         %2 = subtract(%0, %1)               # EncryptedTensor<float64, shape=(1, 432)>
         %3 = 0.07292149215936661            # ClearScalar<float64>
         %4 = multiply(%3, %2)               # EncryptedTensor<float64, shape=(1, 432)>
         %5 = 0.0685318                      # ClearScalar<float32>
         %6 = divide(%4, %5)                 # EncryptedTensor<float64, shape=(1, 432)>
         %7 = 0.0                            # ClearScalar<float32>
         %8 = add(%6, %7)                    # EncryptedTensor<float64, shape=(1, 432)>
         %9 = -8.0                           # ClearScalar<float64>
        %10 = 7.0                            # ClearScalar<float64>
        %11 = clip(%8, %9, %10)              # EncryptedTensor<float64, shape=(1, 432)>
        %12 = rint(%11)                      # EncryptedTensor<float64, shape=(1, 432)>
        %13 = 0.0                            # ClearScalar<float32>
        %14 = subtract(%12, %13)             # EncryptedTensor<float64, shape=(1, 432)>
        %15 = 0.0685318                      # ClearScalar<float32>
        %16 = multiply(%14, %15)             # EncryptedTensor<float64, shape=(1, 432)>
        %17 = 0.06853179633617401            # ClearScalar<float64>
        %18 = divide(%16, %17)               # EncryptedTensor<float64, shape=(1, 432)>
        %19 = 0                              # ClearScalar<uint1>
        %20 = add(%18, %19)                  # EncryptedTensor<float64, shape=(1, 432)>
        %21 = rint(%20)                      # EncryptedTensor<float64, shape=(1, 432)>
        %22 = astype(%21, dtype=int_)        # EncryptedTensor<uint1, shape=(1, 432)>
        return %22
--------------------------------------------------------------------------------

Optimizer
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------

Statistics
--------------------------------------------------------------------------------
size_of_secret_keys: 1159152
size_of_bootstrap_keys: 10296742752
size_of_keyswitch_keys: 8957017848
size_of_inputs: 822089856
size_of_outputs: 32784
p_error: 8.985933003105681e-13
global_p_error: 2.1656695732876227e-09
complexity: 33817488916176.0
programmable_bootstrap_count: 2592
programmable_bootstrap_count_per_parameter: {
    BootstrapKeyParam(polynomial_size=131072, glwe_dimension=1, input_lwe_dimension=1147, level=2, base_log=14, variance=4.70197740328915e-38): 1728
    BootstrapKeyParam(polynomial_size=8192, glwe_dimension=1, input_lwe_dimension=1542, level=1, base_log=22, variance=4.70197740328915e-38): 432
    BootstrapKeyParam(polynomial_size=1024, glwe_dimension=2, input_lwe_dimension=887, level=2, base_log=15, variance=9.940977002694397e-32): 432
}
key_switch_count: 2592
key_switch_count_per_parameter: {
    KeyswitchKeyParam(level=6, base_log=4, variance=2.669736518019816e-17): 1728
    KeyswitchKeyParam(level=1, base_log=19, variance=1.2611069970889797e-23): 432
    KeyswitchKeyParam(level=2, base_log=7, variance=3.8925225717720396e-13): 432
}
packing_key_switch_count: 0
clear_addition_count: 1728
clear_addition_count_per_parameter: {
    LweSecretKeyParam(dimension=131072): 1728
}
encrypted_addition_count: 49248
encrypted_addition_count_per_parameter: {
    LweSecretKeyParam(dimension=131072): 48384
    LweSecretKeyParam(dimension=2048): 864
}
clear_multiplication_count: 49248
clear_multiplication_count_per_parameter: {
    LweSecretKeyParam(dimension=131072): 48384
    LweSecretKeyParam(dimension=2048): 864
}
encrypted_negation_count: 0
--------------------------------------------------------------------------------


Hello @lstk ,
I can indeed see why the key generation took forever, but I am still unsure whether this is expected. I will continue investigating and keep you updated. Additionally, I'll see if we can raise an actual error instead of crashing like this if the issue happens again.

In the meantime, I advise you to look into the rounding_threshold_bits parameter of the compile_brevitas_qat_model method. Setting it to 6 bits, for example, will most likely make the key generation as well as the FHE execution faster without impacting the model's performance score too much (otherwise, try 7 or 8 at most). You can find more information on that feature in this documentation section :slightly_smiling_face:
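
For example, again a sketch based on the compilation call from your class:

self.concrete_module = compile_brevitas_qat_model(
    self.brevitas_model,
    x_train[:100],
    rounding_threshold_bits=6,  # round accumulators to ~6 bits before the table lookups
)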

Hope that helps!

Alright! Thanks as always for the fast support! :smiley:

Hello again @lstk,
I just realized that I forgot to also ask if you could share the model's MLIR, by any chance! You can generate it by setting show_mlir to True in compile_brevitas_qat_model and copy/paste it here, or by running the following after the compilation:

# dump the compiled circuit's MLIR to a text file
with open("mlir.txt", "w") as mlir:
    mlir.write(circuit.mlir)
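
(If I am reading your class correctly, circuit here would be cnn.concrete_module.fhe_circuit.)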

Thanks a lot!

Sure, here it is:

MLIR
--------------------------------------------------------------------------------
module {
  func.func @main(%arg0: tensor<1x1x28x28x!FHE.eint<8>>) -> tensor<1x2x!FHE.esint<9>> {
    %cst = arith.constant dense<[[[[0, 0, -1, 1, 0], [1, 3, 1, 2, 3], [-1, 1, 3, 3, 2], [-1, 0, -1, -1, -1], [1, 1, 0, -1, 0]]], [[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]], [[[0, 2, 2, 1, 0], [1, 2, 1, 1, -1], [1, 2, 2, 0, 0], [1, 2, 2, 1, 0], [2, 2, 3, 1, 1]]]]> : tensor<3x1x5x5xi9>
    %0 = "FHELinalg.to_signed"(%arg0) : (tensor<1x1x28x28x!FHE.eint<8>>) -> tensor<1x1x28x28x!FHE.esint<8>>
    %1 = "FHELinalg.conv2d"(%0, %cst) {dilations = dense<1> : tensor<2xi64>, group = 1 : i64, padding = dense<0> : tensor<4xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x1x28x28x!FHE.esint<8>>, tensor<3x1x5x5xi9>) -> tensor<1x3x24x24x!FHE.esint<8>>
    %cst_0 = arith.constant dense<"0x0...0"> : tensor<1x3x24x24xindex>
    %cst_1 = arith.constant dense<"0x0...0"> : tensor<3x256xi64>
    %2 = "FHELinalg.apply_mapped_lookup_table"(%1, %cst_1, %cst_0) : (tensor<1x3x24x24x!FHE.esint<8>>, tensor<3x256xi64>, tensor<1x3x24x24xindex>) -> tensor<1x3x24x24x!FHE.eint<5>>
    %cst_2 = arith.constant dense<[[[[1, 1], [1, 1]], [[0, 0], [0, 0]], [[0, 0], [0, 0]]], [[[0, 0], [0, 0]], [[1, 1], [1, 1]], [[0, 0], [0, 0]]], [[[0, 0], [0, 0]], [[0, 0], [0, 0]], [[1, 1], [1, 1]]]]> : tensor<3x3x2x2xi6>
    %3 = "FHELinalg.conv2d"(%2, %cst_2) {dilations = dense<1> : tensor<2xi64>, group = 1 : i64, padding = dense<0> : tensor<4xi64>, strides = dense<2> : tensor<2xi64>} : (tensor<1x3x24x24x!FHE.eint<5>>, tensor<3x3x2x2xi6>) -> tensor<1x3x12x12x!FHE.eint<5>>
    %cst_3 = arith.constant dense<[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]> : tensor<32xi64>
    %4 = "FHELinalg.apply_lookup_table"(%3, %cst_3) : (tensor<1x3x12x12x!FHE.eint<5>>, tensor<32xi64>) -> tensor<1x3x12x12x!FHE.eint<2>>
    %collapsed = tensor.collapse_shape %4 [[0], [1, 2, 3]] : tensor<1x3x12x12x!FHE.eint<2>> into tensor<1x432x!FHE.eint<2>>
    %cst_4 = arith.constant dense<[0, 1, 2, 3]> : tensor<4xi64>
    %5 = "FHELinalg.apply_lookup_table"(%collapsed, %cst_4) : (tensor<1x432x!FHE.eint<2>>, tensor<4xi64>) -> tensor<1x432x!FHE.eint<9>>
    %cst_5 = arith.constant dense<"0x01...FF03"> : tensor<432x2xi10>
    %6 = "FHELinalg.to_signed"(%5) : (tensor<1x432x!FHE.eint<9>>) -> tensor<1x432x!FHE.esint<9>>
    %7 = "FHELinalg.matmul_eint_int"(%6, %cst_5) : (tensor<1x432x!FHE.esint<9>>, tensor<432x2xi10>) -> tensor<1x2x!FHE.esint<9>>
    return %7 : tensor<1x2x!FHE.esint<9>>
  }
}
--------------------------------------------------------------------------------

Awesome, thanks a lot!
