API for faster key generation

Are there any new APIs for faster key generation? I found this code, but it takes a long time:

# Generate keys first, this may take some time (up to 30min)
t = time.time()
q_module_fhe.fhe_circuit.keygen()
print(f"Keygen time: {time.time()-t:.2f}s")

Can we also do the low-level operations (key generation, encryption, decryption, computing on ciphertexts) in Concrete ML?

Hello @Rish,
How long does your keygen take? Please be aware that the comment seems a bit outdated, as we ran this key generation in around 2-3 minutes on our machines.

Besides, models like Convolutional Neural Networks (I believe that’s where you took this comment) have complex structures. This means that the underlying FHE circuit gets complex as well, and that key generation will most probably take longer than for “simpler” models like linear or tree-based ones! Currently, we do not provide any other key generation algorithm to make it faster. However, we are always working on improving such key features and you can expect some improvements in future releases.

As for your question on “low level operations”, could you maybe be a bit more specific? I believe you might be more interested in our Concrete library, where you can define, compile and execute any operation or algorithm in FHE (see the small sketch below).
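For illustration, here is a minimal sketch of those low-level steps with Concrete itself (key generation, encryption, homomorphic execution, decryption on a compiled circuit); the tiny function and inputset are only placeholders:

from concrete import fhe

# A toy function marked for FHE compilation; "x" is an encrypted input
@fhe.compiler({"x": "encrypted"})
def add_one(x):
    return x + 1

# Compile over a representative inputset (placeholder range)
circuit = add_one.compile(range(10))

circuit.keygen()                   # explicit key generation
encrypted = circuit.encrypt(5)     # encrypt a clear value
result = circuit.run(encrypted)    # homomorphic evaluation
print(circuit.decrypt(result))     # -> 6

This is the same decomposition as encrypt_run_decrypt, just called step by step.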

However, if you are talking about having the ability to call different steps for a Concrete-ML model, I recommend you to have a look at this documentation page. You will see how to easily call keygen, encrypt, decrypt and run in separate calls. Furthermore, if your goal is to deploy a model in a client-server setting, here’s how to use our dedicated API.

Hope this helps!

Yes, I took it from the CNN on MNIST notebook. The key generation took 1129.50s. I saw some other notebooks where it takes relatively less time; I see that key generation is proportional to the complexity of the quantized module. Thanks for pointing that out.

And yes, I want to execute the basic operations in Concrete ML rather than just compiling and evaluating the model in one go. I believe I can use the functions of Concrete along with the Concrete ML models to build things from scratch.

Say I want to break down the code below (from the quantization-aware training on MNIST notebook) into individual steps: FHE simulation, FHE computation, and before that key generation and encryption, all run sequentially. Can I realise all of that?

# Test in the FHE simulation and in real FHE computation

accuracy = {}
current_index = 3

for use_simulation, use_full_dataset in [(True, True), (True, False), (False, False)]:
    test_data_length = test_data_length_full if use_full_dataset else test_data_length_reduced

    correct_fhe, test_data_shape_0, max_bit_width = compile_and_test(
        model.cpu(),
        use_simulation,
        test_data,
        test_data_length,
        test_target,
        show_mlir,
        current_index,
    )

    current_index += 2
    current_accuracy = correct_fhe / test_data_shape_0

    print(
        f"Accuracy in {'Simulation' if use_simulation else 'FHE'} with length {test_data_length}: "
        f"{correct_fhe}/{test_data_shape_0} = "
        f"{current_accuracy:.4f}, in {max_bit_width}-bits"
    )

    if (use_simulation, use_full_dataset) == (True, True):
        accuracy["FHE Simulation full"] = current_accuracy
    elif (use_simulation, use_full_dataset) == (True, False):
        accuracy["FHE simulation short"] = current_accuracy
    else:
        assert (use_simulation, use_full_dataset) == (False, False)
        accuracy["FHE short"] = current_accuracy

If you look at the compile_and_test function defined in the same notebook, you should be able to spot the different steps (compile, keygen, forward) :slightly_smiling_face: If you also want to separate the steps found in the forward (quantize, encrypt, run, decrypt, dequantize), I believe you’ll have to take a look at the source code (forward, and more specifically its FHE part _fhe_forward). A rough sketch of those separate steps follows.
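To make the separation a bit more concrete, here is a rough sketch of those steps called one by one on a compiled quantized module. The quantize_input / dequantize_output and fhe_circuit names are taken from the API discussed in this thread; treat this as illustrative rather than exact notebook code, and note that torch_model, x_train and x_test are placeholders:

from concrete.ml.torch.compile import compile_brevitas_qat_model

# Compile the trained QAT network into a quantized module
q_module = compile_brevitas_qat_model(torch_model, x_train)

q_module.fhe_circuit.keygen()                 # 1. key generation

x = x_test[[0]]                               # a single sample (batch of 1)
q_x = q_module.quantize_input(x)              # 2. quantize the clear input
enc_x = q_module.fhe_circuit.encrypt(q_x)     # 3. encrypt
enc_y = q_module.fhe_circuit.run(enc_x)       # 4. run on encrypted data
q_y = q_module.fhe_circuit.decrypt(enc_y)     # 5. decrypt
y = q_module.dequantize_output(q_y)           # 6. dequantize the result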

I wanted to know one more thing: can you give an algorithmic description of FHE simulation? I hope I am not bothering you again and again with questions.

No worries, we are happy to answer !

Simulation is done by Concrete (the FHE compiler we use in Concrete ML), so here is an answer @umutsahin recently wrote on the subject!

More precisely, Concrete ML currently simulates using the first version, but you should expect this to change in the near future :wink:

        # Else, use the FHE execution method
        else:
            predict_method = self.fhe_circuit.encrypt_run_decrypt

        # Execute the forward pass in FHE or with simulation
        q_result = predict_method(*q_input)

    results_cnp_circuit_list.append(q_result)

    results_cnp_circuit = numpy.concatenate(results_cnp_circuit_list, axis=0)

    return results_cnp_circuit

Can I somehow decompose this encrypt_run_decrypt into the individual steps of encryption, run and decryption (please point me towards the source code), and then simulate this process in a client-server setting, where the client sends (encrypted inputs, encrypted labels), the compiled FHE circuit evaluates them (the server part), and the client decrypts the result?
Please help me on this.

I believe my previous answer points to the source code you are asking for :slightly_smiling_face:

Probably I am not able to spot it: the forward method has an _fhe_forward function which executes the FHE using fhe_circuit (the compiled model), and it calls the encrypt_run_decrypt method. But I found all these processes (encrypt, run, decrypt) listed separately in the Circuit class as:
def encrypt(
    self,
    *args: Optional[Union[int, np.ndarray, List]],
) -> Optional[Union[Value, Tuple[Optional[Value], ...]]]:
    """
    Encrypt argument(s) to for evaluation.

    Args:
        *args (Optional[Union[int, numpy.ndarray, List]]):
            argument(s) for evaluation

    Returns:
        Optional[Union[Value, Tuple[Optional[Value], ...]]]:
            encrypted argument(s) for evaluation
    """

    if not hasattr(self, "client"):  # pragma: no cover
        self.enable_fhe_execution()

    return self.client.encrypt(*args)

Here there is one more layer I need to uncover: enable_fhe_execution?

def encrypt(
    self,
    *args: Optional[Union[int, np.ndarray, List]],
) -> Optional[Union[Value, Tuple[Optional[Value], ...]]]:
    """
    Encrypt argument(s) to for evaluation.

    Args:
        *args (Optional[Union[int, np.ndarray, List]]):
            argument(s) for evaluation

    Returns:
        Optional[Union[Value, Tuple[Optional[Value], ...]]]:
            encrypted argument(s) for evaluation
    """

    ordered_sanitized_args = validate_input_args(self.specs, *args)

    self.keygen(force=False)
    keyset = self.keys._keyset  # pylint: disable=protected-access

    exporter = ValueExporter.new(keyset, self.specs.client_parameters)
    exported = [
        None
        if arg is None
        else Value(
            exporter.export_tensor(position, arg.flatten().tolist(), list(arg.shape))
            if isinstance(arg, np.ndarray) and arg.shape != ()
            else exporter.export_scalar(position, int(arg))
        )
        for position, arg in enumerate(ordered_sanitized_args)
    ]

    return tuple(exported) if len(exported) != 1 else exported[0]

This is an excerpt from client.py: here, encryption is done using a value exporter.
How do I reconcile the two processes?

The Concrete documentation gives exactly such an example that shows deployment in a client/server setting. The client encrypts and the server executes on this encrypted data. I would suggest you execute the code listed on that page and see if it fits your need.


I also invite you to check some of the links I provided in my first answer:

However, if you are talking about having the ability to call different steps for a Concrete-ML model, I recommend you to have a look at this documentation page. You will see how to easily call keygen, encrypt, decrypt and run in separate calls. Furthermore, if your goal is to deploy a model in a client-server setting, here’s how to use our dedicated API .

In the first link, we give an example where encrypt, run and decrypt are called separately on a built-in model (LogisticRegression). You could do the same with a quantized module, but just be aware that circuits currently only handle batches of 1. Furthermore, the second link gives you some additional information for client-server/deployment settings. Notably, we give an example of such a use as well as how to deploy a model on an AWS instance; a rough sketch of that client-server flow is shown below.
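If it helps, here is a hedged sketch of the client-server flow with the deployment API (FHEModelDev, FHEModelClient and FHEModelServer from concrete.ml.deployment). The method names follow the documentation linked above, so please double-check them against your installed version; model and x_test are placeholders for an already fitted and compiled model and some test data:

from concrete.ml.deployment import FHEModelDev, FHEModelClient, FHEModelServer

# Development machine: save the compiled model artifacts to disk
FHEModelDev(path_dir="./fhe_model", model=model).save()

# Client side: generate keys, then quantize + encrypt + serialize the input
client = FHEModelClient(path_dir="./fhe_model", key_dir="./keys")
client.generate_private_and_evaluation_keys()
evaluation_keys = client.get_serialized_evaluation_keys()
encrypted_input = client.quantize_encrypt_serialize(x_test[[0]])

# Server side: load the circuit and run it on the encrypted input
server = FHEModelServer(path_dir="./fhe_model")
server.load()
encrypted_result = server.run(encrypted_input, evaluation_keys)

# Back on the client: decrypt and dequantize the result
y_pred = client.deserialize_decrypt_dequantize(encrypted_result)

Only the client ever holds the private key; the server works exclusively on serialized ciphertexts and the public evaluation keys.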


I tried to create a smaller QAT model on MNIST (it throws a KeyError):

net = SmallerCNN(n_classes=10, n_bits=6)

model_dev = compile_brevitas_qat_model(net, x_train)
KeyError Traceback (most recent call last)
in <cell line: 2>()
1 net=SmallerCNN(n_classes=10, n_bits=7)
----> 2 model_dev = compile_brevitas_qat_model( net,x_train)

26 frames
/usr/local/lib/python3.10/dist-packages/concrete/ml/torch/compile.py in compile_brevitas_qat_model(torch_model, torch_inputset, n_bits, configuration, artifacts, show_mlir, rounding_threshold_bits, p_error, global_p_error, output_onnx_file, verbose)
413 exporter.onnx_passes.append(“eliminate_nop_pad”)
414 exporter.onnx_passes.append(“fuse_pad_into_conv”)
→ 415 onnx_model = exporter.export(
416 torch_model,
417 args=dummy_input_for_tracing,

/usr/local/lib/python3.10/dist-packages/brevitas/export/onnx/manager.py in export(cls, module, args, export_path, input_shape, input_t, disable_warnings, **onnx_export_kwargs)
157 disable_warnings=True,
158 **onnx_export_kwargs):
→ 159 return cls.export_onnx(module, args, export_path, input_shape, input_t, disable_warnings, **onnx_export_kwargs)

/usr/local/lib/python3.10/dist-packages/brevitas/export/onnx/manager.py in export_onnx(cls, module, args, export_path, input_shape, input_t, disable_warnings, **onnx_export_kwargs)
129 model_bytes = BytesIO()
130 export_target = model_bytes
→ 131 torch.onnx.export(module, args, export_target, **onnx_export_kwargs)
132
133 # restore the model to previous properties

/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions)
502 “”"
503
→ 504 _export(
505 model,
506 args,

/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, onnx_shape_inference, export_modules_as_functions)
1527 _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
1528
→ 1529 graph, params_dict, torch_out = _model_to_graph(
1530 model,
1531 args,

/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, dynamic_axes)
1109
1110 model = _pre_trace_quant_model(model, args)
→ 1111 graph, params, torch_out, module = _create_jit_graph(model, args)
1112 params_dict = _get_named_param_dict(graph, params)
1113

/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py in _create_jit_graph(model, args)
985 return graph, params, torch_out, None
986
→ 987 graph, torch_out = _trace_and_get_graph_from_model(model, args)
988 _C._jit_pass_onnx_lint(graph)
989 state_dict = torch.jit._unique_state_dict(model)

/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py in _trace_and_get_graph_from_model(model, args)
889 prev_autocast_cache_enabled = torch.is_autocast_cache_enabled()
890 torch.set_autocast_cache_enabled(False)
→ 891 trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
892 model,
893 args,

/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py in _get_trace_graph(f, args, kwargs, strict, _force_outplace, return_inputs, _return_inputs_states)
1182 if not isinstance(args, tuple):
1183 args = (args,)
→ 1184 outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
1185 return outs

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
→ 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = ,

/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py in forward(self, *args)
125 return tuple(out_vars)
126
→ 127 graph, out = torch._C._create_graph_by_tracing(
128 wrapper,
129 in_vars + module_state,

/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py in wrapper(*args)
116 if self._return_inputs_states:
117 inputs_states.append(_unflatten(in_args, in_desc))
→ 118 outs.append(self.inner(*trace_inputs))
119 if self._return_inputs_states:
120 inputs_states[0] = (inputs_states[0], trace_inputs)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
→ 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = ,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _slow_forward(self, *input, **kwargs)
1180 recording_scopes = False
1181 try:
→ 1182 result = self.forward(*input, **kwargs)
1183 finally:
1184 if recording_scopes:

in forward(self, x)
57 x = torch.relu(x)
58 x = self.q2(x)
—> 59 x = self.conv2(x)
60 x = torch.relu(x)
61

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1210 input = bw_hook.setup_input_hook(input)
1211
→ 1212 result = forward_call(input, **kwargs)
1213 if _global_forward_hooks or self._forward_hooks:
1214 for hook in (
_global_forward_hooks.values(), *self._forward_hooks.values()):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _slow_forward(self, *input, **kwargs)
1180 recording_scopes = False
1181 try:
→ 1182 result = self.forward(*input, **kwargs)
1183 finally:
1184 if recording_scopes:

/usr/local/lib/python3.10/dist-packages/brevitas/nn/quant_conv.py in forward(self, input)
187
188 def forward(self, input: Union[Tensor, QuantTensor]) → Union[Tensor, QuantTensor]:
→ 189 return self.forward_impl(input)
190
191 def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]):

/usr/local/lib/python3.10/dist-packages/brevitas/nn/quant_layer.py in forward_impl(self, inp)
315
316 quant_input = self.input_quant(inp)
→ 317 quant_weight = self.quant_weight()
318
319 if quant_input.bit_width is not None:

/usr/local/lib/python3.10/dist-packages/brevitas/nn/mixin/parameter.py in quant_weight(self)
53
54 def quant_weight(self):
—> 55 return self.weight_quant(self.weight)
56
57 def int_weight(self, float_datatype=False):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
→ 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = ,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _slow_forward(self, *input, **kwargs)
1180 recording_scopes = False
1181 try:
→ 1182 result = self.forward(*input, **kwargs)
1183 finally:
1184 if recording_scopes:

/usr/local/lib/python3.10/dist-packages/brevitas/proxy/parameter_quant.py in forward(self, x)
84 if self.is_quant_enabled:
85 impl = self.export_handler if self.export_mode else self.tensor_quant
—> 86 out, scale, zero_point, bit_width = impl(x)
87 return QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
88 else: # quantization disabled

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
→ 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = ,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _slow_forward(self, *input, **kwargs)
1180 recording_scopes = False
1181 try:
→ 1182 result = self.forward(*input, **kwargs)
1183 finally:
1184 if recording_scopes:

/usr/local/lib/python3.10/dist-packages/brevitas/export/onnx/handler.py in forward(self, inp, *args, **kwargs)
112 if self.export_debug_name is not None and self.debug_input:
113 inp = debug_fn(inp, ‘.input’)
→ 114 out = self.symbolic_execution(inp, *args, **kwargs)
115 if self.export_debug_name is not None and self.debug_output:
116 if isinstance(out, Tensor):

/usr/local/lib/python3.10/dist-packages/brevitas/export/onnx/qonnx/handler.py in symbolic_execution(self, x)
67
68 def symbolic_execution(self, x: Tensor):
—> 69 quant_weight = self.quant_weights[x.data_ptr()]
70 return super().symbolic_execution(quant_weight)
71

KeyError: 97488528374272

Hello,
If I remember correctly, this issue is related to our pruning feature. Did you run prune() or something similar on your model? More generally, would it be possible to know the code you executed? I’m mostly interested in how you defined the model and what steps you ran until this issue arose. Thanks!

class SmallerCNN(nn.Module):
    """A smaller CNN to classify the sklearn digits data-set."""

    def __init__(self, n_classes, n_bits) -> None:
        super().__init__()

        a_bits = n_bits
        w_bits = n_bits

        self.q1 = qnn.QuantIdentity(bit_width=a_bits, return_quant_tensor=True)
        self.conv1 = qnn.QuantConv2d(1, 4, 3, stride=1, padding=0, weight_bit_width=w_bits)
        self.q2 = qnn.QuantIdentity(bit_width=a_bits, return_quant_tensor=True)
        self.conv2 = qnn.QuantConv2d(4, 8, 2, stride=2, padding=0, weight_bit_width=w_bits)

        self.fc1 = qnn.QuantLinear(
            8 * 3 * 3,
            n_classes,
            bias=True,
            weight_bit_width=w_bits,
        )

        self.toggle_pruning(True)

    def toggle_pruning(self, enable):
        """Enables or removes pruning."""

        # Maximum number of active neurons (i.e., corresponding weight != 0)
        n_active = 6

        # Go through all the convolution layers
        for layer in (self.conv1, self.conv2):
            s = layer.weight.shape

            # Compute fan-in (number of inputs to a neuron)
            # and fan-out (number of neurons in the layer)
            st = [s[0], np.prod(s[1:])]

            # The number of input neurons (fan-in) is the product of
            # the kernel width x height x inChannels.
            if st[1] > n_active:
                if enable:
                    # This will create a forward hook to create a mask tensor that is multiplied
                    # with the weights during forward. The mask will contain 0s or 1s
                    prune.l1_unstructured(layer, "weight", (st[1] - n_active) * st[0])
                else:
                    # When disabling pruning, the mask is multiplied with the weights
                    # and the result is stored in the weights member
                    prune.remove(layer, "weight")

    def forward(self, x):
        """Run inference on the smaller CNN, apply the decision layer on the reshaped conv output."""

        x = self.q1(x)
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.q2(x)
        x = self.conv2(x)
        x = torch.relu(x)

        # Flatten the tensor before passing it to the fully connected layer
        x = x.view(x.size(0), -1)

        x = self.fc1(x)
        return x

I am not sure if there is an issue with pruning; I tried without it as well and it throws an assertion error. (I can train this model with similar accuracy, but I am not able to compile it with compile_brevitas_qat_model.)

Ok, I see, thanks. But then what do you do with this model? How do you train it? Are you following our CNN notebook example steps?

More specifically, since you enable pruning in your model’s initialization, you’ll have to make it permanent by calling .toggle_pruning(False) after training it so that pruned weights are properly “removed”. You can check the notebook’s “Train the CNN” section for example.
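In other words, the overall order of operations would look roughly like this; train_one_epoch, optimizer, train_loader and n_epochs are placeholders for the notebook’s training loop, not exact names:

# Rough order of operations (training-loop names are placeholders)
net = SmallerCNN(n_classes=10, n_bits=6)   # pruning is enabled in __init__

for epoch in range(n_epochs):
    train_one_epoch(net, optimizer, train_loader)

# Make pruning permanent before exporting/compiling the network
net.toggle_pruning(False)

model_dev = compile_brevitas_qat_model(net, x_train)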

Thank you. As you said, I was following a similar approach (training with only a specific bit width). I will try out your suggestions.


net = SmallerCNN(n_classes=10, n_bits=6)

net.toggle_pruning(False)

model_dev = compile_brevitas_qat_model(net, x_train)

AssertionError Traceback (most recent call last)
in <cell line: 3>()
1 net=SmallerCNN(n_classes=10, n_bits=6)
2 net.toggle_pruning(False)
----> 3 model_dev = compile_brevitas_qat_model( net,x_train)

7 frames
/usr/local/lib/python3.10/dist-packages/concrete/ml/torch/compile.py in compile_brevitas_qat_model(torch_model, torch_inputset, n_bits, configuration, artifacts, show_mlir, rounding_threshold_bits, p_error, global_p_error, output_onnx_file, verbose)
455
456 # Compile using the ONNX conversion flow, in QAT mode
→ 457 q_module = compile_onnx_model(
458 onnx_model,
459 torch_inputset,

/usr/local/lib/python3.10/dist-packages/concrete/ml/torch/compile.py in compile_onnx_model(onnx_model, torch_inputset, import_qat, configuration, artifacts, show_mlir, n_bits, rounding_threshold_bits, p_error, global_p_error, verbose)
307 )
308
→ 309 return _compile_torch_or_onnx_model(
310 onnx_model,
311 torch_inputset,

/usr/local/lib/python3.10/dist-packages/concrete/ml/torch/compile.py in _compile_torch_or_onnx_model(model, torch_inputset, import_qat, configuration, artifacts, show_mlir, n_bits, rounding_threshold_bits, p_error, global_p_error, verbose)
149
150 # Build the quantized module
→ 151 quantized_module = build_quantized_module(
152 model=model,
153 torch_inputset=inputset_as_numpy_tuple,

/usr/local/lib/python3.10/dist-packages/concrete/ml/torch/compile.py in build_quantized_module(model, torch_inputset, import_qat, n_bits, rounding_threshold_bits)
97
98 # Build the quantized module
—> 99 quantized_module = post_training_quant.quantize_module(*inputset_as_numpy_tuple)
100
101 return quantized_module

/usr/local/lib/python3.10/dist-packages/concrete/ml/quantization/post_training.py in quantize_module(self, *calibration_data)
591 self._quantize_params()
592
→ 593 self._quantize_layers(*calibration_data)
594
595 # Create quantized module from self.quant_layers_dict

/usr/local/lib/python3.10/dist-packages/concrete/ml/quantization/post_training.py in _quantize_layers(self, *input_calibration_data)
463
464 # For mypy
→ 465 assert_true(
466 all(val is None or isinstance(val, numpy.ndarray) for val in curr_calibration_data)
467 )

/usr/local/lib/python3.10/dist-packages/concrete/ml/common/debugging/custom_assert.py in assert_true(condition, on_error_msg, error_type)
38
39 “”"
—> 40 _custom_assert(condition, on_error_msg, error_type)
41
42

/usr/local/lib/python3.10/dist-packages/concrete/ml/common/debugging/custom_assert.py in _custom_assert(condition, on_error_msg, error_type)
23
24 if not condition:
—> 25 raise error_type(on_error_msg)
26
27

AssertionError:

I had one more doubt: how do I add a compile method to my class so that, after training the model, I can directly run model.compile() and end up with a compiled, trained model?

Hello @Rish,
Just to make sure, the AssertionError is empty? There is no additional message attached to it? In any case, please be sure to train your model before compiling it, else this might create some errors.

As for your question, I’m not sure I understand it. If you want to have a .compile() method on your net, then simply call compile_brevitas_qat_model on self and x_train in it!
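For example, something like this minimal sketch (the configuration and show_mlir keyword arguments are illustrative, adapt them as needed):

from concrete.ml.torch.compile import compile_brevitas_qat_model

class SmallerCNN(nn.Module):
    ...  # __init__, toggle_pruning and forward as defined above

    def compile(self, x_train, configuration=None, show_mlir=False):
        """Compile this (trained) network into a quantized FHE module."""
        self.q_module = compile_brevitas_qat_model(
            self, x_train, configuration=configuration, show_mlir=show_mlir
        )
        return self.q_module

You would then train the network, make pruning permanent, and call net.compile(x_train).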

net = SmallerCNN(n_classes=10, n_bits=3)

net.toggle_pruning(False)

net.compile(x_train)

Below is the traceback of the empty assertion error:

AssertionError Traceback (most recent call last)
in <cell line: 3>()
1 net=SmallerCNN(n_classes=10, n_bits=3)
2 net.toggle_pruning(False)
----> 3 net.compile(x_train)

8 frames
in compile(self, x_train, configuration, show_mlir)
74 “”"
75 # Perform the compilation using the provided training data and configuration
—> 76 self.q_module = compile_brevitas_qat_model(self, x_train, configuration=configuration, show_mlir=show_mlir)
77
78 # we can use this

/usr/local/lib/python3.10/dist-packages/concrete/ml/torch/compile.py in compile_brevitas_qat_model(torch_model, torch_inputset, n_bits, configuration, artifacts, show_mlir, rounding_threshold_bits, p_error, global_p_error, output_onnx_file, verbose)
455
456 # Compile using the ONNX conversion flow, in QAT mode
→ 457 q_module = compile_onnx_model(
458 onnx_model,
459 torch_inputset,

/usr/local/lib/python3.10/dist-packages/concrete/ml/torch/compile.py in compile_onnx_model(onnx_model, torch_inputset, import_qat, configuration, artifacts, show_mlir, n_bits, rounding_threshold_bits, p_error, global_p_error, verbose)
307 )
308
→ 309 return _compile_torch_or_onnx_model(
310 onnx_model,
311 torch_inputset,

/usr/local/lib/python3.10/dist-packages/concrete/ml/torch/compile.py in _compile_torch_or_onnx_model(model, torch_inputset, import_qat, configuration, artifacts, show_mlir, n_bits, rounding_threshold_bits, p_error, global_p_error, verbose)
149
150 # Build the quantized module
→ 151 quantized_module = build_quantized_module(
152 model=model,
153 torch_inputset=inputset_as_numpy_tuple,

/usr/local/lib/python3.10/dist-packages/concrete/ml/torch/compile.py in build_quantized_module(model, torch_inputset, import_qat, n_bits, rounding_threshold_bits)
97
98 # Build the quantized module
—> 99 quantized_module = post_training_quant.quantize_module(*inputset_as_numpy_tuple)
100
101 return quantized_module

/usr/local/lib/python3.10/dist-packages/concrete/ml/quantization/post_training.py in quantize_module(self, *calibration_data)
591 self._quantize_params()
592
→ 593 self._quantize_layers(*calibration_data)
594
595 # Create quantized module from self.quant_layers_dict

/usr/local/lib/python3.10/dist-packages/concrete/ml/quantization/post_training.py in _quantize_layers(self, *input_calibration_data)
463
464 # For mypy
→ 465 assert_true(
466 all(val is None or isinstance(val, numpy.ndarray) for val in curr_calibration_data)
467 )

/usr/local/lib/python3.10/dist-packages/concrete/ml/common/debugging/custom_assert.py in assert_true(condition, on_error_msg, error_type)
38
39 “”"
—> 40 _custom_assert(condition, on_error_msg, error_type)
41
42

/usr/local/lib/python3.10/dist-packages/concrete/ml/common/debugging/custom_assert.py in _custom_assert(condition, on_error_msg, error_type)
23
24 if not condition:
—> 25 raise error_type(on_error_msg)
26
27

AssertionError:

class TinyCNN(nn.Module):
    def __init__(self, n_classes, n_bits) -> None:
        super().__init__()

        a_bits = n_bits
        w_bits = n_bits

        self.q1 = qnn.QuantIdentity(bit_width=a_bits, return_quant_tensor=True)
        self.conv1 = qnn.QuantConv2d(1, 4, 3, stride=1, padding=0, weight_bit_width=w_bits)
        self.q2 = qnn.QuantIdentity(bit_width=a_bits, return_quant_tensor=True)
        self.conv2 = qnn.QuantConv2d(4, 8, 2, stride=2, padding=0, weight_bit_width=w_bits)

        self.fc1 = qnn.QuantLinear(
            8 * 3 * 3,
            n_classes,
            bias=True,
            weight_bit_width=w_bits,
        )

    def forward(self, x):
        x = self.q1(x)
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.q2(x)
        x = self.conv2(x)
        x = torch.relu(x)

        x = x.view(x.size(0), -1)

        x = self.fc1(x)
        return x

I could indeed reproduce the error with this model, which is defined without pruning; it works quite well on the train and test dataloaders.