Quantization aware training (QAT)

This is my CNN model for face recognition:

class FaceRecognitionModel_QAT(nn.Module):
    """Quantization-aware CNN for face recognition built from Brevitas layers.

    The network chains three 3x3 convolution / ReLU / 2x2 max-pool stages and
    two fully connected layers. A ``QuantIdentity`` is placed at the network
    input and after every pooling / flatten step to (re-)quantize activations.

    Args:
        num_classes: Number of output features of the last linear layer.
        n_bits: Number of bits for quantization; the same value is used for
            the activations and for the weights.
    """

    def __init__(self, num_classes, n_bits):
        super().__init__()
        a_bits = n_bits
        w_bits = n_bits
        # Quantize the input by placing a QuantIdentity at the entry point of the network.
        # Author's note: this CNN has 18179856 MACs (figure not re-derived here).
        self.q1 = qnn.QuantIdentity(bit_width=a_bits, return_quant_tensor=True)
        self.conv1 = qnn.QuantConv2d(3, 32, 3, stride=1, padding=1, weight_bit_width=w_bits)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.q2 = qnn.QuantIdentity(bit_width=a_bits, return_quant_tensor=True)
        self.conv2 = qnn.QuantConv2d(32, 64, 3, stride=1, padding=1, weight_bit_width=w_bits)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.q3 = qnn.QuantIdentity(bit_width=a_bits, return_quant_tensor=True)
        self.conv3 = qnn.QuantConv2d(64, 128, 3, stride=1, padding=1, weight_bit_width=w_bits)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.q4 = qnn.QuantIdentity(bit_width=a_bits, return_quant_tensor=True)
        self.flatten = nn.Flatten()
        self.q5 = qnn.QuantIdentity(bit_width=a_bits, return_quant_tensor=True)
        # 128 channels x 16 x 16: assumes 128x128 input images, whose height and
        # width are halved by each of the three 2x2 poolings (128 -> 64 -> 32 -> 16).
        self.fc1 = qnn.QuantLinear(128 * 16 * 16, 512, bias=True, weight_bit_width=w_bits)
        self.q6 = qnn.QuantIdentity(bit_width=a_bits, return_quant_tensor=True)
        self.fc2 = qnn.QuantLinear(512, num_classes, bias=True, weight_bit_width=w_bits)
        # Enable pruning, prepared for training
        self.toggle_pruning(True)

    def toggle_pruning(self, enable):
        """Enable or remove L1 unstructured pruning on the convolution layers.

        The call is idempotent: enabling pruning on a layer that is already
        pruned, or removing it from a layer that is not pruned, is a no-op
        instead of an error.

        Args:
            enable: If true, attach a pruning mask to every convolution whose
                fan-in exceeds ``n_active``. If false, fold the mask into the
                weights and drop the pruning re-parametrization.
        """
        # Maximum number of active neurons (i.e., corresponding weight != 0)
        n_active = 12

        # Go through all the convolution layers
        for layer in (self.conv1, self.conv2, self.conv3):
            shape = layer.weight.shape

            # fan-out: number of neurons (output channels) in the layer.
            # fan-in: number of inputs to a neuron, i.e. the product of
            # inChannels x kernel height x kernel width.
            fan_out = int(shape[0])
            fan_in = int(np.prod(shape[1:]))

            # Layers that are already sparse enough are left untouched.
            if fan_in <= n_active:
                continue

            already_pruned = prune.is_pruned(layer)
            if enable and not already_pruned:
                # Adds a mask tensor of 0s and 1s that is multiplied with the
                # weights during forward. The amount is an absolute count taken
                # over the whole layer, so n_active is the average number of
                # non-zero weights per neuron.
                prune.l1_unstructured(layer, "weight", (fan_in - n_active) * fan_out)
            elif not enable and already_pruned:
                # When disabling pruning, the mask is multiplied with the weights
                # and the result is stored in the weights member.
                prune.remove(layer, "weight")

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Input batch in channels-last layout
                [batch_size, height, width, channels]; it is permuted to
                channels-first before the first layer.

        Returns:
            The output of the last linear layer (``num_classes`` features per sample).
        """
        x = x.permute(0, 3, 1, 2)  # Transpose the dimensions to [batch_size, channels, height, width]
        x = self.q1(x)
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.q2(x)
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.q3(x)
        x = self.pool3(F.relu(self.conv3(x)))
        x = self.q4(x)
        x = self.flatten(x)
        x = self.q5(x)
        x = F.relu(self.fc1(x))
        x = self.q6(x)
        x = self.fc2(x)
        return x
  • The first question is about max pooling: when I do QAT, does it use the same architecture as average pooling?

  • The second question: when I use the function compile_brevitas_qat_model after training, I get this error: “Unsupported: ONNX export of convolution for kernel of unknown shape.”

Hello @ossama_khalyl,

  1. For max-pooling you probably don’t need to add a QuantIdentity since the data is already quantized. But not sure if that answers your question.
  2. We are able to reproduce an error using your code with pruning, and another one without pruning, but neither of them is the one you describe.
    Could you please send us a full traceback so that we can look into it? Also a few things that you can explore to help us debug:
    • move the data permutation out of the model
    • remove pruning from your init

Thank you @luis for your response. I have relocated the data permutation outside of the model and removed the pruning from my CNN. Here is my updated CNN code:

class FaceRecognitionModel_QAT(nn.Module):
    """Quantization-aware CNN for face recognition (variant without pruning).

    Only the network input is explicitly quantized with a ``QuantIdentity``;
    it is followed by three convolution / ReLU / max-pool stages and two
    bias-free fully connected layers.

    Args:
        num_classes: Number of output features of the last linear layer.
        n_bits: Number of bits for quantization, shared by activations and weights.
    """

    def __init__(self, num_classes, n_bits):
        super().__init__()
        act_bits = weight_bits = n_bits

        def quant_conv(in_channels, out_channels):
            # 3x3 quantized convolution that preserves the spatial size.
            return qnn.QuantConv2d(
                in_channels,
                out_channels,
                3,
                stride=1,
                padding=1,
                weight_bit_width=weight_bits,
            )

        def halving_pool():
            # 2x2 max pooling that halves height and width.
            return nn.MaxPool2d(kernel_size=2, stride=2)

        # Quantize the input at the entry point of the network.
        # Author's note: this CNN has 18179856 MACs (figure not re-derived here).
        self.q1 = qnn.QuantIdentity(bit_width=act_bits, return_quant_tensor=True)
        self.conv1 = quant_conv(3, 32)
        self.pool1 = halving_pool()
        self.conv2 = quant_conv(32, 64)
        self.pool2 = halving_pool()
        self.conv3 = quant_conv(64, 128)
        self.pool3 = halving_pool()
        self.flatten = nn.Flatten()
        # 128 channels x 16 x 16: assumes 128x128 input images, halved by each
        # of the three poolings.
        self.fc1 = qnn.QuantLinear(128 * 16 * 16, 512, bias=False, weight_bit_width=weight_bits)
        self.fc2 = qnn.QuantLinear(512, num_classes, bias=False, weight_bit_width=weight_bits)
        # Pruning is intentionally not enabled in this variant.

    def forward(self, x):
        """Run the network on a batch that is already channels-first.

        Args:
            x: Input batch; presumably [batch_size, 3, height, width], since no
                permutation is applied before the first convolution.

        Returns:
            The output of the last linear layer.
        """
        features = self.q1(x)
        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        for conv, pool in stages:
            features = pool(F.relu(conv(features)))
        hidden = F.relu(self.fc1(self.flatten(features)))
        return self.fc2(hidden)

The complete error message is:

[W shape_type_inference.cpp:1920] Warning: The shape inference of onnx.brevitas::Quant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)
[W shape_type_inference.cpp:1920] Warning: The shape inference of onnx.brevitas::Quant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)
================ Diagnostic Run torch.onnx.export version 2.0.1 ================
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 2 WARNING 0 ERROR ========================
2 WARNING were not printed due to the log level.

---------------------------------------------------------------------------
SymbolicValueError                        Traceback (most recent call last)
Cell In[16], line 5
      3 sim_time = []
      4 for idx in range(len(bit_range)):
----> 5     q_module = compile_brevitas_qat_model(nets[idx], x_train_modified)
      7     accum_bits.append(q_module.fhe_circuit.graph.maximum_integer_bit_width())
      9     start_time = time.time()

File ~/anaconda3/envs/myenv/lib/python3.10/site-packages/concrete/ml/torch/compile.py:415, in compile_brevitas_qat_model(torch_model, torch_inputset, n_bits, configuration, artifacts, show_mlir, rounding_threshold_bits, p_error, global_p_error, output_onnx_file, verbose)
    413 exporter.onnx_passes.append("eliminate_nop_pad")
    414 exporter.onnx_passes.append("fuse_pad_into_conv")
--> 415 onnx_model = exporter.export(
    416     torch_model,
    417     args=dummy_input_for_tracing,
    418     export_path=str(output_onnx_file_path),
    419     keep_initializers_as_inputs=False,
    420     opset_version=OPSET_VERSION_FOR_ONNX_EXPORT,
    421 )
    422 onnx_model = remove_initializer_from_input(onnx_model)
    424 if n_bits is None:

File ~/anaconda3/envs/myenv/lib/python3.10/site-packages/brevitas/export/onnx/manager.py:159, in ONNXBaseManager.export(cls, module, args, export_path, input_shape, input_t, disable_warnings, **onnx_export_kwargs)
    149 @classmethod
    150 def export(
    151         cls,
   (...)
    157         disable_warnings=True,
    158         **onnx_export_kwargs):
--> 159     return cls.export_onnx(module, args, export_path, input_shape, input_t, disable_warnings, **onnx_export_kwargs)

File ~/anaconda3/envs/myenv/lib/python3.10/site-packages/brevitas/export/onnx/manager.py:131, in ONNXBaseManager.export_onnx(cls, module, args, export_path, input_shape, input_t, disable_warnings, **onnx_export_kwargs)
    129         model_bytes = BytesIO()
    130         export_target = model_bytes
--> 131     torch.onnx.export(module, args, export_target, **onnx_export_kwargs)
    133 # restore the model to previous properties
    134 module.apply(lambda m: _restore_inp_caching_mode(m))

File ~/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/onnx/utils.py:506, in export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions)
    188 @_beartype.beartype
    189 def export(
    190     model: Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction],
   (...)
    206     export_modules_as_functions: Union[bool, Collection[Type[torch.nn.Module]]] = False,
    207 ) -> None:
    208     r"""Exports a model into ONNX format.
    209 
    210     If ``model`` is not a :class:`torch.jit.ScriptModule` nor a
   (...)
    503             All errors are subclasses of :class:`errors.OnnxExporterError`.
    504     """
--> 506     _export(
    507         model,
    508         args,
    509         f,
    510         export_params,
    511         verbose,
    512         training,
    513         input_names,
    514         output_names,
    515         operator_export_type=operator_export_type,
    516         opset_version=opset_version,
    517         do_constant_folding=do_constant_folding,
    518         dynamic_axes=dynamic_axes,
    519         keep_initializers_as_inputs=keep_initializers_as_inputs,
    520         custom_opsets=custom_opsets,
    521         export_modules_as_functions=export_modules_as_functions,
    522     )

File ~/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/onnx/utils.py:1548, in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, onnx_shape_inference, export_modules_as_functions)
   1545     dynamic_axes = {}
   1546 _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
-> 1548 graph, params_dict, torch_out = _model_to_graph(
   1549     model,
   1550     args,
   1551     verbose,
   1552     input_names,
   1553     output_names,
   1554     operator_export_type,
   1555     val_do_constant_folding,
   1556     fixed_batch_size=fixed_batch_size,
   1557     training=training,
   1558     dynamic_axes=dynamic_axes,
   1559 )
   1561 # TODO: Don't allocate a in-memory string for the protobuf
   1562 defer_weight_export = (
   1563     export_type is not _exporter_states.ExportTypes.PROTOBUF_FILE
   1564 )

File ~/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/onnx/utils.py:1117, in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, dynamic_axes)
   1114 params_dict = _get_named_param_dict(graph, params)
   1116 try:
-> 1117     graph = _optimize_graph(
   1118         graph,
   1119         operator_export_type,
   1120         _disable_torch_constant_prop=_disable_torch_constant_prop,
   1121         fixed_batch_size=fixed_batch_size,
   1122         params_dict=params_dict,
   1123         dynamic_axes=dynamic_axes,
   1124         input_names=input_names,
   1125         module=module,
   1126     )
   1127 except Exception as e:
   1128     torch.onnx.log("Torch IR graph at exception: ", graph)

File ~/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/onnx/utils.py:665, in _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop, fixed_batch_size, params_dict, dynamic_axes, input_names, module)
    662     _C._jit_pass_onnx_set_dynamic_input_shape(graph, dynamic_axes, input_names)
    663 _C._jit_pass_onnx_lint(graph)
--> 665 graph = _C._jit_pass_onnx(graph, operator_export_type)
    666 _C._jit_pass_onnx_lint(graph)
    667 _C._jit_pass_lint(graph)

File ~/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/onnx/utils.py:1891, in _run_symbolic_function(graph, block, node, inputs, env, operator_export_type)
   1886     if symbolic_fn is not None:
   1887         # TODO Wrap almost identical attrs assignment or comment the difference.
   1888         attrs = {
   1889             k: symbolic_helper._node_get(node, k) for k in node.attributeNames()
   1890         }
-> 1891         return symbolic_fn(graph_context, *inputs, **attrs)
   1893 attrs = {
   1894     k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k)
   1895     for k in node.attributeNames()
   1896 }
   1897 if namespace == "onnx":
   1898     # Clone node to trigger ONNX shape inference

File ~/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/onnx/symbolic_helper.py:306, in parse_args.<locals>.decorator.<locals>.wrapper(g, *args, **kwargs)
    300 if len(kwargs) == 1:
    301     assert "_outputs" in kwargs, (
    302         f"Symbolic function {fn.__name__}'s '**kwargs' can only contain "
    303         f"'_outputs' key at '**kwargs'. "
    304         f"{FILE_BUG_MSG}"
    305     )
--> 306 return fn(g, *args, **kwargs)

File ~/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/onnx/symbolic_opset9.py:2451, in _convolution(g, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32)
   2448     kernel_shape = None
   2450 if kernel_shape is None or any([i is None for i in kernel_shape]):
-> 2451     raise errors.SymbolicValueError(
   2452         "Unsupported: ONNX export of convolution for kernel of unknown shape.",
   2453         input,
   2454     )
   2456 args = [input, weight]
   2457 # ONNX only supports 1D bias

SymbolicValueError: Unsupported: ONNX export of convolution for kernel of unknown shape.  [Caused by the value 'x defined in (%x : Float(*, *, *, *, strides=[49152, 1, 384, 3], requires_grad=0, device=cpu) = onnx.brevitas::Quant[narrow=0, rounding_mode="ROUND", signed=1](%x.193, %scale, %zero_point, %bit_width), scope: __main__.FaceRecognitionModel_QAT::/brevitas.nn.quant_activation.QuantIdentity::q1/brevitas.proxy.runtime_quant.ActQuantProxyFromInjector::act_quant/brevitas.export.onnx.qonnx.handler.BrevitasActQuantProxyHandler::export_handler # /home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/autograd/function.py:506:0
)' (type 'Tensor') in the TorchScript graph. The containing node has kind 'onnx.brevitas::Quant'.] 
    (node defined in /home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/autograd/function.py(506): apply
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/brevitas/export/onnx/qonnx/handler.py(49): symbolic_execution
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/brevitas/export/onnx/handler.py(114): forward
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/brevitas/proxy/runtime_quant.py(150): forward
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/brevitas/nn/quant_layer.py(148): forward
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/tmp/ipykernel_393/1994898414.py(51): forward
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/jit/_trace.py(118): wrapper
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/jit/_trace.py(127): forward
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/jit/_trace.py(1268): _get_trace_graph
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/onnx/utils.py(893): _trace_and_get_graph_from_model
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/onnx/utils.py(989): _create_jit_graph
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/onnx/utils.py(1113): _model_to_graph
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/onnx/utils.py(1548): _export
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/onnx/utils.py(506): export
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/brevitas/export/onnx/manager.py(131): export_onnx
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/brevitas/export/onnx/manager.py(159): export
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/concrete/ml/torch/compile.py(415): compile_brevitas_qat_model
/tmp/ipykernel_393/382210858.py(5): <module>
/home/khalyl/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3508): run_code
/home/khalyl/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3448): run_ast_nodes
/home/khalyl/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3269): run_cell_async
/home/khalyl/.local/lib/python3.10/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner
/home/khalyl/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3064): _run_cell
/home/khalyl/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3009): run_cell
/home/khalyl/.local/lib/python3.10/site-packages/ipykernel/zmqshell.py(546): run_cell
/home/khalyl/.local/lib/python3.10/site-packages/ipykernel/ipkernel.py(422): do_execute
/home/khalyl/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py(740): execute_request
/home/khalyl/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py(412): dispatch_shell
/home/khalyl/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py(505): process_one
/home/khalyl/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py(516): dispatch_queue
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/asyncio/events.py(80): _run
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/asyncio/base_events.py(1906): _run_once
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/asyncio/base_events.py(603): run_forever
/home/khalyl/.local/lib/python3.10/site-packages/tornado/platform/asyncio.py(195): start
/home/khalyl/.local/lib/python3.10/site-packages/ipykernel/kernelapp.py(736): start
/home/khalyl/.local/lib/python3.10/site-packages/traitlets/config/application.py(1043): launch_instance
/home/khalyl/.local/lib/python3.10/site-packages/ipykernel_launcher.py(17): <module>
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/runpy.py(86): _run_code
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/runpy.py(196): _run_module_as_main
)

    Inputs:
        #0: x.193 defined in (%x.193 : Float(1, 3, 128, 128, strides=[49152, 1, 384, 3], requires_grad=0, device=cpu), %conv1.weight : Float(32, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=1, device=cpu), %conv1.bias : Float(32, strides=[1], requires_grad=1, device=cpu), %conv2.weight : Float(64, 32, 3, 3, strides=[288, 9, 3, 1], requires_grad=1, device=cpu), %conv2.bias : Float(64, strides=[1], requires_grad=1, device=cpu), %conv3.weight : Float(128, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=1, device=cpu), %conv3.bias : Float(128, strides=[1], requires_grad=1, device=cpu), %fc1.weight : Float(512, 32768, strides=[32768, 1], requires_grad=1, device=cpu), %fc2.weight : Float(158, 512, strides=[512, 1], requires_grad=1, device=cpu) = prim::Param()
    )  (type 'Tensor')
        #1: scale defined in (%scale : Float(requires_grad=0, device=cpu) = onnx::Constant[value={0.5}](), scope: __main__.FaceRecognitionModel_QAT::/brevitas.nn.quant_activation.QuantIdentity::q1/brevitas.proxy.runtime_quant.ActQuantProxyFromInjector::act_quant/brevitas.export.onnx.qonnx.handler.BrevitasActQuantProxyHandler::export_handler # /home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/autograd/function.py:506:0
    )  (type 'Tensor')
        #2: zero_point defined in (%zero_point : Float(requires_grad=0, device=cpu) = onnx::Constant[value={0}](), scope: __main__.FaceRecognitionModel_QAT::/brevitas.nn.quant_activation.QuantIdentity::q1/brevitas.proxy.runtime_quant.ActQuantProxyFromInjector::act_quant/brevitas.export.onnx.qonnx.handler.BrevitasActQuantProxyHandler::export_handler # /home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/autograd/function.py:506:0
    )  (type 'Tensor')
        #3: bit_width defined in (%bit_width : Float(requires_grad=0, device=cpu) = onnx::Constant[value={2}](), scope: __main__.FaceRecognitionModel_QAT::/brevitas.nn.quant_activation.QuantIdentity::q1/brevitas.proxy.runtime_quant.ActQuantProxyFromInjector::act_quant/brevitas.export.onnx.qonnx.handler.BrevitasActQuantProxyHandler::export_handler # /home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/brevitas/export/onnx/qonnx/handler.py:46:0
    )  (type 'Tensor')
    Outputs:
        #0: x defined in (%x : Float(*, *, *, *, strides=[49152, 1, 384, 3], requires_grad=0, device=cpu) = onnx.brevitas::Quant[narrow=0, rounding_mode="ROUND", signed=1](%x.193, %scale, %zero_point, %bit_width), scope: __main__.FaceRecognitionModel_QAT::/brevitas.nn.quant_activation.QuantIdentity::q1/brevitas.proxy.runtime_quant.ActQuantProxyFromInjector::act_quant/brevitas.export.onnx.qonnx.handler.BrevitasActQuantProxyHandler::export_handler # /home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/autograd/function.py:506:0
    )  (type 'Tensor')

Can someone answer me or give me a solution for this error?

Hello @ossama_khalyl , sorry for the delay.

We weren’t able to replicate the exact issue you are having. Could you please share here the code that you are using for compilation and the list of package versions in your Python environment (mainly concrete-ml, concrete-python, torch, onnx and brevitas)?

I am having the same problem at the moment. Did you find a solution?

Hello @lstk , as mentioned previously we were not able to replicate the issue.
Could you give us a minimal example to replicate your bug please?
(That means some source code and the versions of Concrete ML, Concrete Python and Brevitas)

Hello luis,

Sure, here’s an example:

import sys
import numpy as np
# Sklearn
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
# PyTorch
import torch
from torchsummary import summary
# Concrete ML
import concrete.ml as cml
from concrete.ml.torch.compile import compile_brevitas_qat_model
# Brevitas
import brevitas

# Report the interpreter and package versions needed to reproduce the issue
for label, version in (
    ("Python version: ", sys.version),
    ("PyTorch version: ", torch.__version__),
    ("Concrete ML version: ", cml.__version__),
    ("Brevitas version: ", brevitas.__version__),
):
    print(label, version)

# Load the digits dataset and give every sample a single-channel 8x8 image shape
x, y = load_digits(return_X_y=True)
x = x.reshape(-1, 1, 8, 8)

# Hold out 20% of the samples for testing
xtrain_np, xtest_np, ytrain_np, ytest_np = train_test_split(
    x, y, test_size=0.2, random_state=42
)

# Convert the numpy arrays to PyTorch tensors (float features, integer labels)
xtrain, xtest = (torch.FloatTensor(arr) for arr in (xtrain_np, xtest_np))
ytrain, ytest = (torch.LongTensor(arr) for arr in (ytrain_np, ytest_np))

# Bit width shared by every quantized layer of the model
n_bits = 2

# Quantized input -> conv -> ReLU -> average pooling -> flatten -> linear,
# with a QuantIdentity re-quantizing the activations after pooling and flatten
layers = [
    brevitas.nn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True),
    brevitas.nn.QuantConv2d(
        in_channels=1, out_channels=3, kernel_size=3, weight_bit_width=n_bits
    ),
    brevitas.nn.QuantReLU(bit_width=n_bits),
    torch.nn.AvgPool2d(kernel_size=2, stride=2),
    brevitas.nn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True),
    torch.nn.Flatten(),
    brevitas.nn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True),
    brevitas.nn.QuantLinear(
        in_features=27, out_features=10, bias=True, weight_bit_width=n_bits
    ),
]
brevitas_model = torch.nn.Sequential(*layers)

# Train for three full-batch steps; enough to reproduce the export error
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(brevitas_model.parameters(), lr=0.001, weight_decay=0.001)

for _ in range(3):
    optimizer.zero_grad()
    loss = criterion(brevitas_model(xtrain), ytrain)
    loss.backward()
    optimizer.step()

# Compile the trained model using the first 100 training samples as input set
fhe_module = compile_brevitas_qat_model(brevitas_model, xtrain[:100])

Output:

Python version:  3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:36:39) [GCC 12.3.0]
PyTorch version:  2.1.2
Concrete ML version:  1.3.0
Brevitas version:  0.8.0
[W shape_type_inference.cpp:1974] Warning: The shape inference of onnx.brevitas::Quant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)
[W shape_type_inference.cpp:1974] Warning: The shape inference of onnx.brevitas::Quant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)

Error Message:

---------------------------------------------------------------------------
SymbolicValueError                        Traceback (most recent call last)
Cell In[1], line 60
     57     loss.backward()
     58     optimizer.step()
---> 60 fhe_module = compile_brevitas_qat_model(brevitas_model, xtrain[:100], n_bits=n_bits)

File ~/.local/lib/python3.10/site-packages/concrete/ml/torch/compile.py:443, in compile_brevitas_qat_model(torch_model, torch_inputset, n_bits, configuration, artifacts, show_mlir, rounding_threshold_bits, p_error, global_p_error, output_onnx_file, verbose, inputs_encryption_status)
    430 # Here we add a "eliminate_nop_pad" optimization step for onnxoptimizer
    431 # https://github.com/onnx/optimizer/blob/master/onnxoptimizer/passes/eliminate_nop_pad.h#L5
    432 # It deletes 0-values padding.
   (...)
    436 # In the export function, the `args` parameter is used instead of the `input_shape` one in
    437 # order to be able to handle multi-inputs models
    438 exporter.onnx_passes += [
    439     "eliminate_nop_pad",
    440     "fuse_pad_into_conv",
    441     "fuse_matmul_add_bias_into_gemm",
    442 ]
--> 443 onnx_model = exporter.export(
    444     torch_model,
    445     args=dummy_input_for_tracing,
    446     export_path=str(output_onnx_file_path),
    447     keep_initializers_as_inputs=False,
    448     opset_version=OPSET_VERSION_FOR_ONNX_EXPORT,
    449 )
    450 onnx_model = remove_initializer_from_input(onnx_model)
    452 if n_bits is None:

File ~/.local/lib/python3.10/site-packages/brevitas/export/onnx/manager.py:159, in ONNXBaseManager.export(cls, module, args, export_path, input_shape, input_t, disable_warnings, **onnx_export_kwargs)
    149 @classmethod
    150 def export(
    151         cls,
   (...)
    157         disable_warnings=True,
    158         **onnx_export_kwargs):
--> 159     return cls.export_onnx(module, args, export_path, input_shape, input_t, disable_warnings, **onnx_export_kwargs)

File ~/.local/lib/python3.10/site-packages/brevitas/export/onnx/manager.py:131, in ONNXBaseManager.export_onnx(cls, module, args, export_path, input_shape, input_t, disable_warnings, **onnx_export_kwargs)
    129         model_bytes = BytesIO()
    130         export_target = model_bytes
--> 131     torch.onnx.export(module, args, export_target, **onnx_export_kwargs)
    133 # restore the model to previous properties
    134 module.apply(lambda m: _restore_inp_caching_mode(m))

File ~/anaconda3/envs/test/lib/python3.10/site-packages/torch/onnx/utils.py:516, in export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions, autograd_inlining)
    189 @_beartype.beartype
    190 def export(
    191     model: Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction],
   (...)
    208     autograd_inlining: Optional[bool] = True,
    209 ) -> None:
    210     r"""Exports a model into ONNX format.
    211 
    212     If ``model`` is not a :class:`torch.jit.ScriptModule` nor a
   (...)
    513             All errors are subclasses of :class:`errors.OnnxExporterError`.
    514     """
--> 516     _export(
    517         model,
    518         args,
    519         f,
    520         export_params,
    521         verbose,
    522         training,
    523         input_names,
    524         output_names,
    525         operator_export_type=operator_export_type,
    526         opset_version=opset_version,
    527         do_constant_folding=do_constant_folding,
    528         dynamic_axes=dynamic_axes,
    529         keep_initializers_as_inputs=keep_initializers_as_inputs,
    530         custom_opsets=custom_opsets,
    531         export_modules_as_functions=export_modules_as_functions,
    532         autograd_inlining=autograd_inlining,
    533     )

File ~/anaconda3/envs/test/lib/python3.10/site-packages/torch/onnx/utils.py:1596, in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, onnx_shape_inference, export_modules_as_functions, autograd_inlining)
   1593     dynamic_axes = {}
   1594 _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
-> 1596 graph, params_dict, torch_out = _model_to_graph(
   1597     model,
   1598     args,
   1599     verbose,
   1600     input_names,
   1601     output_names,
   1602     operator_export_type,
   1603     val_do_constant_folding,
   1604     fixed_batch_size=fixed_batch_size,
   1605     training=training,
   1606     dynamic_axes=dynamic_axes,
   1607 )
   1609 # TODO: Don't allocate a in-memory string for the protobuf
   1610 defer_weight_export = (
   1611     export_type is not _exporter_states.ExportTypes.PROTOBUF_FILE
   1612 )

File ~/anaconda3/envs/test/lib/python3.10/site-packages/torch/onnx/utils.py:1139, in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, dynamic_axes)
   1136 params_dict = _get_named_param_dict(graph, params)
   1138 try:
-> 1139     graph = _optimize_graph(
   1140         graph,
   1141         operator_export_type,
   1142         _disable_torch_constant_prop=_disable_torch_constant_prop,
   1143         fixed_batch_size=fixed_batch_size,
   1144         params_dict=params_dict,
   1145         dynamic_axes=dynamic_axes,
   1146         input_names=input_names,
   1147         module=module,
   1148     )
   1149 except Exception as e:
   1150     torch.onnx.log("Torch IR graph at exception: ", graph)

File ~/anaconda3/envs/test/lib/python3.10/site-packages/torch/onnx/utils.py:677, in _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop, fixed_batch_size, params_dict, dynamic_axes, input_names, module)
    674     _C._jit_pass_onnx_set_dynamic_input_shape(graph, dynamic_axes, input_names)
    675 _C._jit_pass_onnx_lint(graph)
--> 677 graph = _C._jit_pass_onnx(graph, operator_export_type)
    678 _C._jit_pass_onnx_lint(graph)
    679 _C._jit_pass_lint(graph)

File ~/anaconda3/envs/test/lib/python3.10/site-packages/torch/onnx/utils.py:1940, in _run_symbolic_function(graph, block, node, inputs, env, operator_export_type)
   1935     if symbolic_fn is not None:
   1936         # TODO Wrap almost identical attrs assignment or comment the difference.
   1937         attrs = {
   1938             k: symbolic_helper._node_get(node, k) for k in node.attributeNames()
   1939         }
-> 1940         return symbolic_fn(graph_context, *inputs, **attrs)
   1942 attrs = {
   1943     k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k)
   1944     for k in node.attributeNames()
   1945 }
   1946 if namespace == "onnx":
   1947     # Clone node to trigger ONNX shape inference

File ~/anaconda3/envs/test/lib/python3.10/site-packages/torch/onnx/symbolic_helper.py:306, in parse_args.<locals>.decorator.<locals>.wrapper(g, *args, **kwargs)
    300 if len(kwargs) == 1:
    301     assert "_outputs" in kwargs, (
    302         f"Symbolic function {fn.__name__}'s '**kwargs' can only contain "
    303         f"'_outputs' key at '**kwargs'. "
    304         f"{FILE_BUG_MSG}"
    305     )
--> 306 return fn(g, *args, **kwargs)

File ~/anaconda3/envs/test/lib/python3.10/site-packages/torch/onnx/symbolic_opset9.py:2519, in _convolution(g, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32)
   2516     kernel_shape = None
   2518 if kernel_shape is None or any(i is None for i in kernel_shape):
-> 2519     raise errors.SymbolicValueError(
   2520         "Unsupported: ONNX export of convolution for kernel of unknown shape.",
   2521         input,
   2522     )
   2524 args = [input, weight]
   2525 # ONNX only supports 1D bias

SymbolicValueError: Unsupported: ONNX export of convolution for kernel of unknown shape.  [Caused by the value 'x defined in (%x : Float(*, *, *, *, strides=[64, 64, 8, 1], requires_grad=0, device=cpu) = onnx.brevitas::Quant[narrow=0, rounding_mode="ROUND", signed=1](%x.173, %scale, %zero_point, %bit_width), scope: torch.nn.modules.container.Sequential::/brevitas.nn.quant_activation.QuantIdentity::0/brevitas.proxy.runtime_quant.ActQuantProxyFromInjector::act_quant/brevitas.export.onnx.qonnx.handler.BrevitasActQuantProxyHandler::export_handler # /home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/torch/autograd/function.py:539:0
)' (type 'Tensor') in the TorchScript graph. The containing node has kind 'onnx.brevitas::Quant'.] 
    (node defined in /home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/torch/autograd/function.py(539): apply
/home/lukas/.local/lib/python3.10/site-packages/brevitas/export/onnx/qonnx/handler.py(49): symbolic_execution
/home/lukas/.local/lib/python3.10/site-packages/brevitas/export/onnx/handler.py(114): forward
/home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/torch/nn/modules/module.py(1508): _slow_forward
/home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/torch/nn/modules/module.py(1527): _call_impl
/home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/torch/nn/modules/module.py(1518): _wrapped_call_impl
/home/lukas/.local/lib/python3.10/site-packages/brevitas/proxy/runtime_quant.py(150): forward
/home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/torch/nn/modules/module.py(1508): _slow_forward
/home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/torch/nn/modules/module.py(1527): _call_impl
/home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/torch/nn/modules/module.py(1518): _wrapped_call_impl
/home/lukas/.local/lib/python3.10/site-packages/brevitas/nn/quant_layer.py(148): forward
/home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/torch/nn/modules/module.py(1508): _slow_forward
/home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/torch/nn/modules/module.py(1527): _call_impl
/home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/torch/nn/modules/module.py(1518): _wrapped_call_impl
/home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/torch/nn/modules/container.py(215): forward
/home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/torch/nn/modules/module.py(1508): _slow_forward
/home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/torch/nn/modules/module.py(1527): _call_impl
/home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/torch/nn/modules/module.py(1518): _wrapped_call_impl
/home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/torch/jit/_trace.py(124): wrapper
/home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/torch/jit/_trace.py(133): forward
/home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/torch/nn/modules/module.py(1527): _call_impl
/home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/torch/nn/modules/module.py(1518): _wrapped_call_impl
/home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/torch/jit/_trace.py(1285): _get_trace_graph
/home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/torch/onnx/utils.py(915): _trace_and_get_graph_from_model
/home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/torch/onnx/utils.py(1011): _create_jit_graph
/home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/torch/onnx/utils.py(1135): _model_to_graph
/home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/torch/onnx/utils.py(1596): _export
/home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/torch/onnx/utils.py(516): export
/home/lukas/.local/lib/python3.10/site-packages/brevitas/export/onnx/manager.py(131): export_onnx
/home/lukas/.local/lib/python3.10/site-packages/brevitas/export/onnx/manager.py(159): export
/home/lukas/.local/lib/python3.10/site-packages/concrete/ml/torch/compile.py(443): compile_brevitas_qat_model
/tmp/ipykernel_1601331/456282306.py(60): <module>
/home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3553): run_code
/home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3493): run_ast_nodes
/home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3311): run_cell_async
/home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner
/home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3106): _run_cell
/home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3051): run_cell
/home/lukas/.local/lib/python3.10/site-packages/ipykernel/zmqshell.py(549): run_cell
/home/lukas/.local/lib/python3.10/site-packages/ipykernel/ipkernel.py(426): do_execute
/home/lukas/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py(758): execute_request
/home/lukas/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py(418): dispatch_shell
/home/lukas/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py(513): process_one
/home/lukas/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py(524): dispatch_queue
/home/lukas/anaconda3/envs/test/lib/python3.10/asyncio/events.py(80): _run
/home/lukas/anaconda3/envs/test/lib/python3.10/asyncio/base_events.py(1909): _run_once
/home/lukas/anaconda3/envs/test/lib/python3.10/asyncio/base_events.py(603): run_forever
/home/lukas/.local/lib/python3.10/site-packages/tornado/platform/asyncio.py(195): start
/home/lukas/.local/lib/python3.10/site-packages/ipykernel/kernelapp.py(737): start
/home/lukas/.local/lib/python3.10/site-packages/traitlets/config/application.py(1053): launch_instance
/home/lukas/.local/lib/python3.10/site-packages/ipykernel_launcher.py(17): <module>
/home/lukas/anaconda3/envs/test/lib/python3.10/runpy.py(86): _run_code
/home/lukas/anaconda3/envs/test/lib/python3.10/runpy.py(196): _run_module_as_main
)

    Inputs:
        #0: x.173 defined in (%x.173 : Float(1, 1, 8, 8, strides=[64, 64, 8, 1], requires_grad=0, device=cpu), %0.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value : Float(requires_grad=0, device=cpu), %1.weight : Float(3, 1, 3, 3, strides=[9, 9, 3, 1], requires_grad=1, device=cpu), %1.bias : Float(3, strides=[1], requires_grad=1, device=cpu), %2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value : Float(requires_grad=0, device=cpu), %4.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value : Float(requires_grad=0, device=cpu), %6.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value : Float(requires_grad=0, device=cpu), %7.weight : Float(10, 27, strides=[27, 1], requires_grad=1, device=cpu), %7.bias : Float(10, strides=[1], requires_grad=1, device=cpu) = prim::Param()
    )  (type 'Tensor')
        #1: scale defined in (%scale : Float(requires_grad=0, device=cpu) = onnx::Constant[value={8}](), scope: torch.nn.modules.container.Sequential::/brevitas.nn.quant_activation.QuantIdentity::0/brevitas.proxy.runtime_quant.ActQuantProxyFromInjector::act_quant/brevitas.export.onnx.qonnx.handler.BrevitasActQuantProxyHandler::export_handler # /home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/torch/autograd/function.py:539:0
    )  (type 'Tensor')
        #2: zero_point defined in (%zero_point : Float(requires_grad=0, device=cpu) = onnx::Constant[value={0}](), scope: torch.nn.modules.container.Sequential::/brevitas.nn.quant_activation.QuantIdentity::0/brevitas.proxy.runtime_quant.ActQuantProxyFromInjector::act_quant/brevitas.export.onnx.qonnx.handler.BrevitasActQuantProxyHandler::export_handler # /home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/torch/autograd/function.py:539:0
    )  (type 'Tensor')
        #3: bit_width defined in (%bit_width : Float(requires_grad=0, device=cpu) = onnx::Constant[value={2}](), scope: torch.nn.modules.container.Sequential::/brevitas.nn.quant_activation.QuantIdentity::0/brevitas.proxy.runtime_quant.ActQuantProxyFromInjector::act_quant/brevitas.export.onnx.qonnx.handler.BrevitasActQuantProxyHandler::export_handler # /home/lukas/.local/lib/python3.10/site-packages/brevitas/export/onnx/qonnx/handler.py:46:0
    )  (type 'Tensor')
    Outputs:
        #0: x defined in (%x : Float(*, *, *, *, strides=[64, 64, 8, 1], requires_grad=0, device=cpu) = onnx.brevitas::Quant[narrow=0, rounding_mode="ROUND", signed=1](%x.173, %scale, %zero_point, %bit_width), scope: torch.nn.modules.container.Sequential::/brevitas.nn.quant_activation.QuantIdentity::0/brevitas.proxy.runtime_quant.ActQuantProxyFromInjector::act_quant/brevitas.export.onnx.qonnx.handler.BrevitasActQuantProxyHandler::export_handler # /home/lukas/anaconda3/envs/test/lib/python3.10/site-packages/torch/autograd/function.py:539:0
    )  (type 'Tensor')
```

Hello @lstk , thanks for the information!
The error you are facing is caused by the version of PyTorch you have installed.
For now, you should downgrade torch to version 1.13.1.

We should support torch 2.x in the future, but in the meantime this is the workaround :slightly_smiling_face:

Thanks! It's working now :slight_smile:

1 Like