Keygen problem when iterating over a loop

Hey all,

Does anyone have an idea about the following problem?

I have a class with the following methods:

from concrete.ml.torch.compile import compile_brevitas_qat_model

class concrete_cnn():

    def __init__(self, n_bits,...):
        self.n_bits = n_bits
        self.concrete_module = None
        ...

    def train_brevitas_model(self,...):
        # 1. QAT training
        ...
        # 2. compile model
        self.concrete_module = compile_brevitas_qat_model(self.brevitas_model, x_train[:100])

        # Warn if the largest intermediate integer in the circuit exceeds 16 bits
        bit_width = self.concrete_module.fhe_circuit.graph.maximum_integer_bit_width()
        if bit_width > 16:
            print(f'WARNING: Maximum bit-width of 16 exceeded ({bit_width})')

        return self

    def evaluate(self,...):
        # 1. Key generation
        self.concrete_module.fhe_circuit.keygen(force=True)
        ...
        # 2. predict on test set
        ...
        return y_pred

This class works fine for my dataset, and I don’t have any problems when running a single n_bits configuration (it takes about 10 minutes for n_bits=6 without exceeding the maximum accumulator bit-width).

When I now try to log the runtime and performance for every n_bits configuration in a loop (see code below), my notebook gets stuck in key generation for hours and eventually crashes. This always happens from around bits/n_bits >= 4; everything below that worked fine. There is also no indication of a RAM issue.

for bits in range(2,7):
    cnn = concrete_cnn(n_bits=bits)
    # train
    cnn = cnn.train_brevitas_model(x_train,...)
    # evaluate
    y_pred = cnn.evaluate(x_test,...)
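While the root cause is being investigated, one possible workaround (an assumption on my side, not something suggested in this thread) is to run each n_bits configuration in a fresh Python process with a timeout. A hang or crash during key generation then only ends that one run, the keys are freed when the process exits, and the runtime of every configuration still gets logged. This does not make key generation itself any faster. `run_one_config.py` is a hypothetical script that would build `concrete_cnn(n_bits=int(sys.argv[1]))`, train it and evaluate it.

```python
import subprocess
import sys
import time


def sweep_n_bits(script_path, bit_range, timeout_s=3600.0):
    """Run `python script_path <bits>` once per bit-width, each in a fresh process.

    Returns {bits: {"status": ..., "seconds": ...}} so runtimes can be logged
    even when a configuration hangs or crashes.
    """
    results = {}
    for bits in bit_range:
        start = time.perf_counter()
        try:
            proc = subprocess.run(
                [sys.executable, script_path, str(bits)],
                capture_output=True,
                text=True,
                timeout=timeout_s,  # subprocess.run kills the child if it exceeds this
            )
            status = "ok" if proc.returncode == 0 else f"exit code {proc.returncode}"
        except subprocess.TimeoutExpired:
            status = "timeout"
        results[bits] = {"status": status, "seconds": time.perf_counter() - start}
    return results
```

For example, `sweep_n_bits("run_one_config.py", range(2, 7), timeout_s=2 * 3600)` would mirror the loop above; a crash killed by the OS shows up as a negative exit code instead of taking the notebook kernel down.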

Here is a screenshot of my Jupyter logs from VS Code:

Hello @lstk,
Would you also be able to provide the logs from compiling the problematic model? You can get them by setting verbose=True in compile_brevitas_qat_model and then copy-pasting them here! This would help us a lot in determining the source of your issue :wink:

Thanks
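If the verbose output is long, it can be captured into a string and written to a file instead of being copied from the notebook. This is a generic standard-library sketch; it only captures what is printed through Python’s `sys.stdout`, so anything written directly to the process’s stdout by native code would not end up in the string.

```python
import contextlib
import io


def call_and_capture_stdout(fn, *args, **kwargs):
    """Call fn(*args, **kwargs) and return (result, everything it printed via sys.stdout)."""
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        result = fn(*args, **kwargs)
    return result, buffer.getvalue()


# Possible usage with the compile call from this thread:
# module, log = call_and_capture_stdout(
#     compile_brevitas_qat_model, brevitas_model, x_train[:100], verbose=True
# )
# with open("compile_log.txt", "w") as f:
#     f.write(log)
```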

This is from the iteration (n_bits=4) in which the Jupyter kernel crashed:

Computation Graph
--------------------------------------------------------------------------------
 %0 = _inp_1                                                                                       # EncryptedTensor<uint3, shape=(1, 1, 28, 28)>        ∈ [0, 7]
 %1 = [[[[ 0  0  ...   2 -2]]]]                                                                    # ClearTensor<int4, shape=(3, 1, 5, 5)>               ∈ [-7, 7]            @ /1/Conv.conv
 %2 = conv2d(%0, %1, [0 0 0], pads=[0, 0, 0, 0], strides=(1, 1), dilations=(1, 1), group=1)        # EncryptedTensor<int10, shape=(1, 3, 24, 24)>        ∈ [-231, 300]        @ /1/Conv.conv
 %3 = subgraph(%2)                                                                                 # EncryptedTensor<uint4, shape=(1, 3, 24, 24)>        ∈ [0, 15]
 %4 = [[[[1 1]   ...   [1 1]]]]                                                                    # ClearTensor<uint1, shape=(3, 3, 2, 2)>              ∈ [0, 1]             @ /3/AveragePool.avgpool
 %5 = conv2d(%3, %4, [0 0 0], pads=[0, 0, 0, 0], strides=(2, 2), dilations=(1, 1), group=1)        # EncryptedTensor<uint6, shape=(1, 3, 12, 12)>        ∈ [0, 60]            @ /3/AveragePool.avgpool
 %6 = subgraph(%5)                                                                                 # EncryptedTensor<uint3, shape=(1, 3, 12, 12)>        ∈ [0, 7]
 %7 = reshape(%6, newshape=(1, 432))                                                               # EncryptedTensor<uint3, shape=(1, 432)>              ∈ [0, 7]
 %8 = subgraph(%7)                                                                                 # EncryptedTensor<uint3, shape=(1, 432)>              ∈ [0, 7]
 %9 = [[ 0  0] [ ... ] [-3  3]]                                                                    # ClearTensor<int4, shape=(432, 2)>                   ∈ [-6, 7]            @ /7/Gemm.matmul
%10 = matmul(%8, %9)                                                                               # EncryptedTensor<int11, shape=(1, 2)>                ∈ [-583, 594]        @ /7/Gemm.matmul
return %10

Subgraphs:

    %3 = subgraph(%2):

         %0 = input                            # EncryptedTensor<uint3, shape=(1, 3, 24, 24)>          @ /1/Conv.conv
         %1 = astype(%0, dtype=float64)        # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
         %2 = 0                                # ClearScalar<uint1>
         %3 = add(%1, %2)                      # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
         %4 = [[[[0]]  [[0]]  [[0]]]]          # ClearTensor<uint1, shape=(1, 3, 1, 1)>
         %5 = subtract(%3, %4)                 # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
         %6 = 0.00508421244397636              # ClearScalar<float64>
         %7 = multiply(%6, %5)                 # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
         %8 = [[[[ 6.755 ... 34e-01]]]]        # ClearTensor<float32, shape=(1, 3, 1, 1)>
         %9 = add(%7, %8)                      # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
        %10 = 0                                # ClearScalar<uint1>
        %11 = maximum(%9, %10)                 # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
        %12 = 0.03629405                       # ClearScalar<float32>
        %13 = divide(%11, %12)                 # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
        %14 = 0.0                              # ClearScalar<float32>
        %15 = add(%13, %14)                    # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
        %16 = 0.0                              # ClearScalar<float64>
        %17 = 15.0                             # ClearScalar<float64>
        %18 = clip(%15, %16, %17)              # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
        %19 = rint(%18)                        # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
        %20 = 0.0                              # ClearScalar<float32>
        %21 = subtract(%19, %20)               # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
        %22 = 0.03629405                       # ClearScalar<float32>
        %23 = multiply(%21, %22)               # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
        %24 = 0.03629405051469803              # ClearScalar<float64>
        %25 = divide(%23, %24)                 # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
        %26 = 0                                # ClearScalar<uint1>
        %27 = add(%25, %26)                    # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
        %28 = rint(%27)                        # EncryptedTensor<float64, shape=(1, 3, 24, 24)>
        %29 = astype(%28, dtype=int_)          # EncryptedTensor<uint1, shape=(1, 3, 24, 24)>
        return %29

    %6 = subgraph(%5):

         %0 = input                            # EncryptedTensor<uint1, shape=(1, 3, 12, 12)>          @ /3/AveragePool.avgpool
         %1 = astype(%0, dtype=float64)        # EncryptedTensor<float64, shape=(1, 3, 12, 12)>
         %2 = 0.25                             # ClearScalar<float64>
         %3 = multiply(%1, %2)                 # EncryptedTensor<float64, shape=(1, 3, 12, 12)>
         %4 = 0                                # ClearScalar<uint1>
         %5 = subtract(%3, %4)                 # EncryptedTensor<float64, shape=(1, 3, 12, 12)>
         %6 = 0.03629405051469803              # ClearScalar<float64>
         %7 = multiply(%5, %6)                 # EncryptedTensor<float64, shape=(1, 3, 12, 12)>
         %8 = 0.07292149                       # ClearScalar<float32>
         %9 = divide(%7, %8)                   # EncryptedTensor<float64, shape=(1, 3, 12, 12)>
        %10 = 0.0                              # ClearScalar<float32>
        %11 = add(%9, %10)                     # EncryptedTensor<float64, shape=(1, 3, 12, 12)>
        %12 = -8.0                             # ClearScalar<float64>
        %13 = 7.0                              # ClearScalar<float64>
        %14 = clip(%11, %12, %13)              # EncryptedTensor<float64, shape=(1, 3, 12, 12)>
        %15 = rint(%14)                        # EncryptedTensor<float64, shape=(1, 3, 12, 12)>
        %16 = 0.0                              # ClearScalar<float32>
        %17 = subtract(%15, %16)               # EncryptedTensor<float64, shape=(1, 3, 12, 12)>
        %18 = 0.07292149                       # ClearScalar<float32>
        %19 = multiply(%17, %18)               # EncryptedTensor<float64, shape=(1, 3, 12, 12)>
        %20 = 0.07292149215936661              # ClearScalar<float64>
        %21 = divide(%19, %20)                 # EncryptedTensor<float64, shape=(1, 3, 12, 12)>
        %22 = 0                                # ClearScalar<uint1>
        %23 = add(%21, %22)                    # EncryptedTensor<float64, shape=(1, 3, 12, 12)>
        %24 = rint(%23)                        # EncryptedTensor<float64, shape=(1, 3, 12, 12)>
        %25 = astype(%24, dtype=int_)          # EncryptedTensor<uint1, shape=(1, 3, 12, 12)>
        return %25

    %8 = subgraph(%7):

         %0 = input                          # EncryptedTensor<uint1, shape=(1, 432)>
         %1 = 0.0                            # ClearScalar<float64>
         %2 = subtract(%0, %1)               # EncryptedTensor<float64, shape=(1, 432)>
         %3 = 0.07292149215936661            # ClearScalar<float64>
         %4 = multiply(%3, %2)               # EncryptedTensor<float64, shape=(1, 432)>
         %5 = 0.0685318                      # ClearScalar<float32>
         %6 = divide(%4, %5)                 # EncryptedTensor<float64, shape=(1, 432)>
         %7 = 0.0                            # ClearScalar<float32>
         %8 = add(%6, %7)                    # EncryptedTensor<float64, shape=(1, 432)>
         %9 = -8.0                           # ClearScalar<float64>
        %10 = 7.0                            # ClearScalar<float64>
        %11 = clip(%8, %9, %10)              # EncryptedTensor<float64, shape=(1, 432)>
        %12 = rint(%11)                      # EncryptedTensor<float64, shape=(1, 432)>
        %13 = 0.0                            # ClearScalar<float32>
        %14 = subtract(%12, %13)             # EncryptedTensor<float64, shape=(1, 432)>
        %15 = 0.0685318                      # ClearScalar<float32>
        %16 = multiply(%14, %15)             # EncryptedTensor<float64, shape=(1, 432)>
        %17 = 0.06853179633617401            # ClearScalar<float64>
        %18 = divide(%16, %17)               # EncryptedTensor<float64, shape=(1, 432)>
        %19 = 0                              # ClearScalar<uint1>
        %20 = add(%18, %19)                  # EncryptedTensor<float64, shape=(1, 432)>
        %21 = rint(%20)                      # EncryptedTensor<float64, shape=(1, 432)>
        %22 = astype(%21, dtype=int_)        # EncryptedTensor<uint1, shape=(1, 432)>
        return %22
--------------------------------------------------------------------------------

Optimizer
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------

Statistics
--------------------------------------------------------------------------------
size_of_secret_keys: 1159152
size_of_bootstrap_keys: 10296742752
size_of_keyswitch_keys: 8957017848
size_of_inputs: 822089856
size_of_outputs: 32784
p_error: 8.985933003105681e-13
global_p_error: 2.1656695732876227e-09
complexity: 33817488916176.0
programmable_bootstrap_count: 2592
programmable_bootstrap_count_per_parameter: {
    BootstrapKeyParam(polynomial_size=131072, glwe_dimension=1, input_lwe_dimension=1147, level=2, base_log=14, variance=4.70197740328915e-38): 1728
    BootstrapKeyParam(polynomial_size=8192, glwe_dimension=1, input_lwe_dimension=1542, level=1, base_log=22, variance=4.70197740328915e-38): 432
    BootstrapKeyParam(polynomial_size=1024, glwe_dimension=2, input_lwe_dimension=887, level=2, base_log=15, variance=9.940977002694397e-32): 432
}
key_switch_count: 2592
key_switch_count_per_parameter: {
    KeyswitchKeyParam(level=6, base_log=4, variance=2.669736518019816e-17): 1728
    KeyswitchKeyParam(level=1, base_log=19, variance=1.2611069970889797e-23): 432
    KeyswitchKeyParam(level=2, base_log=7, variance=3.8925225717720396e-13): 432
}
packing_key_switch_count: 0
clear_addition_count: 1728
clear_addition_count_per_parameter: {
    LweSecretKeyParam(dimension=131072): 1728
}
encrypted_addition_count: 49248
encrypted_addition_count_per_parameter: {
    LweSecretKeyParam(dimension=131072): 48384
    LweSecretKeyParam(dimension=2048): 864
}
clear_multiplication_count: 49248
clear_multiplication_count_per_parameter: {
    LweSecretKeyParam(dimension=131072): 48384
    LweSecretKeyParam(dimension=2048): 864
}
encrypted_negation_count: 0
--------------------------------------------------------------------------------


Hello @lstk,
I can indeed see why the key generation took forever, but I am still unsure whether this is expected. I will keep investigating and keep you updated. I’ll also see if we can raise an actual error instead of crashing if this issue happens again.

In the meantime, I advise you to look into the rounding_threshold_bits parameter of compile_brevitas_qat_model. Setting it to 6 bits, for example, will most likely make both key generation and FHE execution faster without hurting the model’s score too much (otherwise, try 7 or 8 at most). You can find more information on this feature in this documentation section :slightly_smiling_face:
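To give an intuition for what rounding accumulators to a threshold does, here is a simplified plaintext sketch (not Concrete ML’s implementation, which works on encrypted values and may handle ties and overflow differently): the accumulator is rounded so that only its most significant bits are kept, so the following table lookup only has to cover a much smaller set of inputs. Taking the first conv2d output from the log above, an int10 value in [-231, 300], and keeping 6 bits means dropping the 4 least significant bits.

```python
def round_to_msbs(value, in_bits, keep_bits):
    """Round an `in_bits`-bit integer to the nearest multiple of 2**(in_bits - keep_bits).

    Ties round up; overflow past the original range is not handled here.
    """
    drop = in_bits - keep_bits
    if drop <= 0:
        return value
    half = 1 << (drop - 1)
    # Python's >> floors for negative numbers too, so this is round-half-up.
    return ((value + half) >> drop) << drop


def distinct_lut_inputs(lo, hi, in_bits, keep_bits):
    """Count the distinct rounded values over the integer range [lo, hi]."""
    return len({round_to_msbs(v, in_bits, keep_bits) for v in range(lo, hi + 1)})
```

With these helpers, `round_to_msbs(300, 10, 6)` gives 304 and `round_to_msbs(-231, 10, 6)` gives -224, and the number of distinct lookup inputs over [-231, 300] drops from 532 to 34.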

Hope that helps!


Alright! Thanks as always for the fast support! :smiley:


Hello again @lstk,
I just realized that I forgot to ask whether you could also share the model’s MLIR! You can generate it by setting show_mlir=True in compile_brevitas_qat_model and copy-pasting the output here, or by running the following after compilation:

# `circuit` is the compiled FHE circuit, e.g. self.concrete_module.fhe_circuit in your class
with open("mlir.txt", "w") as mlir:
    mlir.write(circuit.mlir)

Thanks a lot!

Sure, here it is:

MLIR
--------------------------------------------------------------------------------
module {
  func.func @main(%arg0: tensor<1x1x28x28x!FHE.eint<8>>) -> tensor<1x2x!FHE.esint<9>> {
    %cst = arith.constant dense<[[[[0, 0, -1, 1, 0], [1, 3, 1, 2, 3], [-1, 1, 3, 3, 2], [-1, 0, -1, -1, -1], [1, 1, 0, -1, 0]]], [[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]], [[[0, 2, 2, 1, 0], [1, 2, 1, 1, -1], [1, 2, 2, 0, 0], [1, 2, 2, 1, 0], [2, 2, 3, 1, 1]]]]> : tensor<3x1x5x5xi9>
    %0 = "FHELinalg.to_signed"(%arg0) : (tensor<1x1x28x28x!FHE.eint<8>>) -> tensor<1x1x28x28x!FHE.esint<8>>
    %1 = "FHELinalg.conv2d"(%0, %cst) {dilations = dense<1> : tensor<2xi64>, group = 1 : i64, padding = dense<0> : tensor<4xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x1x28x28x!FHE.esint<8>>, tensor<3x1x5x5xi9>) -> tensor<1x3x24x24x!FHE.esint<8>>
    %cst_0 = arith.constant dense<"0x0...0"> : tensor<1x3x24x24xindex>
    %cst_1 = arith.constant dense<"0x0...0"> : tensor<3x256xi64>
    %2 = "FHELinalg.apply_mapped_lookup_table"(%1, %cst_1, %cst_0) : (tensor<1x3x24x24x!FHE.esint<8>>, tensor<3x256xi64>, tensor<1x3x24x24xindex>) -> tensor<1x3x24x24x!FHE.eint<5>>
    %cst_2 = arith.constant dense<[[[[1, 1], [1, 1]], [[0, 0], [0, 0]], [[0, 0], [0, 0]]], [[[0, 0], [0, 0]], [[1, 1], [1, 1]], [[0, 0], [0, 0]]], [[[0, 0], [0, 0]], [[0, 0], [0, 0]], [[1, 1], [1, 1]]]]> : tensor<3x3x2x2xi6>
    %3 = "FHELinalg.conv2d"(%2, %cst_2) {dilations = dense<1> : tensor<2xi64>, group = 1 : i64, padding = dense<0> : tensor<4xi64>, strides = dense<2> : tensor<2xi64>} : (tensor<1x3x24x24x!FHE.eint<5>>, tensor<3x3x2x2xi6>) -> tensor<1x3x12x12x!FHE.eint<5>>
    %cst_3 = arith.constant dense<[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]> : tensor<32xi64>
    %4 = "FHELinalg.apply_lookup_table"(%3, %cst_3) : (tensor<1x3x12x12x!FHE.eint<5>>, tensor<32xi64>) -> tensor<1x3x12x12x!FHE.eint<2>>
    %collapsed = tensor.collapse_shape %4 [[0], [1, 2, 3]] : tensor<1x3x12x12x!FHE.eint<2>> into tensor<1x432x!FHE.eint<2>>
    %cst_4 = arith.constant dense<[0, 1, 2, 3]> : tensor<4xi64>
    %5 = "FHELinalg.apply_lookup_table"(%collapsed, %cst_4) : (tensor<1x432x!FHE.eint<2>>, tensor<4xi64>) -> tensor<1x432x!FHE.eint<9>>
    %cst_5 = arith.constant dense<"0x01...FF03"> : tensor<432x2xi10>
    %6 = "FHELinalg.to_signed"(%5) : (tensor<1x432x!FHE.eint<9>>) -> tensor<1x432x!FHE.esint<9>>
    %7 = "FHELinalg.matmul_eint_int"(%6, %cst_5) : (tensor<1x432x!FHE.esint<9>>, tensor<432x2xi10>) -> tensor<1x2x!FHE.esint<9>>
    return %7 : tensor<1x2x!FHE.esint<9>>
  }
}
--------------------------------------------------------------------------------
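A quick way to see which encrypted bit-widths ended up in a compiled circuit is to scan the MLIR for `!FHE.eint<N>` and `!FHE.esint<N>` types; in the dump above, for instance, the widest encrypted type is 9 bits. This is a minimal standard-library sketch whose regex only assumes the type syntax visible in this dump.

```python
import re
from collections import Counter


def fhe_int_widths(mlir_text):
    """Count occurrences of each encrypted integer type as (kind, bits) pairs."""
    return Counter(
        (kind, int(bits))
        for kind, bits in re.findall(r"!FHE\.(eint|esint)<(\d+)>", mlir_text)
    )


def max_fhe_bit_width(mlir_text):
    """Largest N over all !FHE.eint<N> / !FHE.esint<N> types (0 if none)."""
    return max((bits for _, bits in fhe_int_widths(mlir_text)), default=0)
```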

Awesome, thanks a lot!


Hello,

I think my built-in Concrete ML neural network has the same problem (weight and activation n_bits too high) with the following settings:

from concrete.ml.sklearn import NeuralNetClassifier
from skorch.callbacks import EpochScoring

self.model = NeuralNetClassifier(
    lr=self.learning_rate,
    max_epochs=self.epochs,
    batch_size=128,
    callbacks=[EpochScoring(scoring='accuracy', name='train_acc', on_train=True)],
    verbose=0,
    **{
        'module__n_layers': self.n_layers + 1,
        'module__n_w_bits': 6,
        'module__n_a_bits': 6,
        'module__n_accum_bits': 16,
        'module__n_hidden_neurons_multiplier': self.neuron_factor,
        'optimizer__weight_decay': self.weight_decay,
    }
)

So my question is: is it also possible to set rounding_threshold_bits=6 for the NeuralNetClassifier?

Maybe through the configuration parameter of the compile method (see below)? If so, how do I set this up?

Help on method compile in module concrete.ml.sklearn.base:

compile(X: 'Data', configuration: 'Optional[Configuration]' = None, artifacts: 'Optional[DebugArtifacts]' = None, show_mlir: 'bool' = False, p_error: 'Optional[float]' = None, global_p_error: 'Optional[float]' = None, verbose: 'bool' = False) -> 'Circuit' method of concrete.ml.sklearn.qnn.NeuralNetClassifier instance
    Compile the model.
    
    Args:
        X (Data): A representative set of input values used for building cryptographic
            parameters, as a Numpy array, Torch tensor, Pandas DataFrame or List. This is
            usually the training data-set or a sub-set of it.
        configuration (Optional[Configuration]): Options to use for compilation. Default
            to None.
        artifacts (Optional[DebugArtifacts]): Artifacts information about the compilation
            process to store for debugging. Default to None.
        show_mlir (bool): Indicate if the MLIR graph should be printed during compilation.
            Default to False.
        p_error (Optional[float]): Probability of error of a single PBS. A p_error value cannot
            be given if a global_p_error value is already set. Default to None, which sets this
            error to a default value.
        global_p_error (Optional[float]): Probability of error of the full circuit. A
            global_p_error value cannot be given if a p_error value is already set. This feature
            is not supported during the FHE simulation mode, meaning the probability is
            currently set to 0. Default to None, which sets this error to a default value.
        verbose (bool): Indicate if compilation information should be printed
            during compilation. Default to False.
    
    Returns:
        Circuit: The compiled Circuit.

Hello again @lstk,
Just so you know, built-in QNNs already do some rounding under the hood, but you’re right: we unfortunately do not currently provide a way for users to set their own rounding threshold. We have noted this feature request internally and will try to expose it in the future; it’s just not a priority for now. If you are willing to add this support yourself as a contribution to our project, we’ll be delighted to help you with it!

In the meantime, I would suggest simply lowering n_w_bits and/or n_a_bits (empirically, if you need to choose, weights can usually be quantized with fewer bits than activations). I would not advise setting n_accum_bits higher than 15 or 16, as this could lead to compilation overflow. You could also experiment with pruning through the n_prune_neurons_percentage parameter.
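Following that advice, a small sweep could generate NeuralNetClassifier keyword arguments in which weights never get more bits than activations and the accumulator stays at 16 bits or below. This is a sketch: the `module__` keys mirror the snippet earlier in the thread, `module__n_prune_neurons_percentage` assumes the pruning parameter is passed through the same `module__` prefix, and the grid values are arbitrary.

```python
from itertools import product


def qnn_param_grid(weight_bits, activation_bits, accum_bits=(14, 15, 16), prune_pcts=(0.0,)):
    """Yield keyword dicts for NeuralNetClassifier, skipping configs the advice rules out."""
    for w, a, acc, prune in product(weight_bits, activation_bits, accum_bits, prune_pcts):
        if w > a or acc > 16:
            continue  # keep weights <= activations and n_accum_bits no higher than 16
        yield {
            "module__n_w_bits": w,
            "module__n_a_bits": a,
            "module__n_accum_bits": acc,
            "module__n_prune_neurons_percentage": prune,
        }
```

Each dict could then be unpacked into the constructor, e.g. `NeuralNetClassifier(lr=..., max_epochs=..., **params)`, alongside the other fixed arguments.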

Hope that helps!

Thanks for the reply! I’ll play around with the parameters then :slight_smile:
