Quantized model compilation fails with a KeyError

Hi. I am trying to compile a quantized Brevitas model, but compilation fails with a KeyError. Here is the complete stack trace:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[16], line 2
      1 # with torch.no_grad():
----> 2 quantized_module = compile_brevitas_qat_model(
      3     transformer, 
      4     img_tensor, 
      5     rounding_threshold_bits={"n_bits": 6, "method": "approximate"}
      6 )

File /opt/homebrew/lib/python3.9/site-packages/concrete/ml/torch/compile.py:560, in compile_brevitas_qat_model(torch_model, torch_inputset, n_bits, configuration, artifacts, show_mlir, rounding_threshold_bits, p_error, global_p_error, output_onnx_file, verbose, inputs_encryption_status, reduce_sum_copy)
    554 assert_true(
    555     n_bits is None or isinstance(n_bits, (int, dict)),
    556     "The n_bits parameter must be either a dictionary, an integer or None",
    557 )
    559 # Compile using the ONNX conversion flow, in QAT mode
--> 560 q_module = compile_onnx_model(
    561     onnx_model,
    562     torch_inputset,
    563     n_bits=n_bits,
    564     import_qat=True,
    565     artifacts=artifacts,
    566     show_mlir=show_mlir,
    567     rounding_threshold_bits=rounding_threshold_bits,
    568     configuration=configuration,
    569     p_error=p_error,
    570     global_p_error=global_p_error,
    571     verbose=verbose,
    572     inputs_encryption_status=inputs_encryption_status,
    573     reduce_sum_copy=reduce_sum_copy,
    574 )
    576 # Remove the tempfile if we used one
    577 if use_tempfile:

File /opt/homebrew/lib/python3.9/site-packages/concrete/ml/torch/compile.py:409, in compile_onnx_model(onnx_model, torch_inputset, import_qat, configuration, artifacts, show_mlir, n_bits, rounding_threshold_bits, p_error, global_p_error, verbose, inputs_encryption_status, reduce_sum_copy)
    402 onnx_model_opset_version = get_onnx_opset_version(onnx_model)
    403 assert_true(
    404     onnx_model_opset_version == OPSET_VERSION_FOR_ONNX_EXPORT,
    405     f"ONNX version must be {OPSET_VERSION_FOR_ONNX_EXPORT} "
    406     f"but it is {onnx_model_opset_version}",
    407 )
--> 409 return _compile_torch_or_onnx_model(
    410     onnx_model,
    411     torch_inputset,
    412     import_qat,
    413     configuration=configuration,
    414     artifacts=artifacts,
    415     show_mlir=show_mlir,
    416     n_bits=n_bits,
    417     rounding_threshold_bits=rounding_threshold_bits,
    418     p_error=p_error,
    419     global_p_error=global_p_error,
    420     verbose=verbose,
    421     inputs_encryption_status=inputs_encryption_status,
    422     reduce_sum_copy=reduce_sum_copy,
    423 )

File /opt/homebrew/lib/python3.9/site-packages/concrete/ml/torch/compile.py:214, in _compile_torch_or_onnx_model(model, torch_inputset, import_qat, configuration, artifacts, show_mlir, n_bits, rounding_threshold_bits, p_error, global_p_error, verbose, inputs_encryption_status, reduce_sum_copy, composition_mapping)
    208     raise ValueError(
    209         "Composition must be enabled in 'configuration' in order to trigger a re-quantization "
    210         "step on the circuit's outputs."
    211     )
    213 # Build the quantized module
--> 214 quantized_module = build_quantized_module(
    215     model=model,
    216     torch_inputset=inputset_as_numpy_tuple,
    217     import_qat=import_qat,
    218     n_bits=n_bits,
    219     rounding_threshold_bits=rounding_threshold_bits,
    220     reduce_sum_copy=reduce_sum_copy,
    221 )
    223 # Check that p_error or global_p_error is not set in both the configuration and in the direct
    224 # parameters
    225 check_there_is_no_p_error_options_in_configuration(configuration)

File /opt/homebrew/lib/python3.9/site-packages/concrete/ml/torch/compile.py:127, in build_quantized_module(model, torch_inputset, import_qat, n_bits, rounding_threshold_bits, reduce_sum_copy)
    121 post_training_quant = post_training(n_bits, numpy_model, rounding_threshold_bits)
    123 # Build the quantized module
    124 # FIXME: mismatch here. We traced with dummy_input_for_tracing which made some operator
    125 # only work over shape of (1, ., .). For example, some reshape have newshape hardcoded based
    126 # on the inputset we sent in the NumpyModule.
--> 127 quantized_module = post_training_quant.quantize_module(*inputset_as_numpy_tuple)
    129 # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4127
    130 if reduce_sum_copy:

File /opt/homebrew/lib/python3.9/site-packages/concrete/ml/quantization/post_training.py:702, in ONNXConverter.quantize_module(self, *calibration_data)
    699 self._quantize_layers(*calibration_data)
    701 # Create quantized module from self.quant_layers_dict
--> 702 quantized_module = QuantizedModule(
    703     ordered_module_input_names=(
    704         graph_input.name for graph_input in self.numpy_model.onnx_model.graph.input
    705     ),
    706     ordered_module_output_names=(
    707         graph_output.name for graph_output in self.numpy_model.onnx_model.graph.output
    708     ),
    709     quant_layers_dict=self.quant_ops_dict,
    710     onnx_model=self.numpy_model.onnx_model,
    711 )
    713 adapter = PowerOfTwoScalingRoundPBSAdapter(quantized_module)
    714 # Apply the round PBS optimization if possible

File /opt/homebrew/lib/python3.9/site-packages/concrete/ml/quantization/quantized_module.py:138, in QuantizedModule.__init__(self, ordered_module_input_names, ordered_module_output_names, quant_layers_dict, onnx_model)
    136 # Initialize output quantizers based on quant_layers_dict
    137 if self.quant_layers_dict:
--> 138     self.output_quantizers = self._set_output_quantizers()
    139 else:
    140     self.output_quantizers = []

File /opt/homebrew/lib/python3.9/site-packages/concrete/ml/quantization/quantized_module.py:280, in QuantizedModule._set_output_quantizers(self)
    274 def _set_output_quantizers(self) -> List[UniformQuantizer]:
    275     """Get the output quantizers.
    276 
    277     Returns:
    278         List[UniformQuantizer]: List of output quantizers.
    279     """
--> 280     output_layers = list(
    281         self.quant_layers_dict[output_name][1]
    282         for output_name in self.ordered_module_output_names
    283     )
    284     output_quantizers = list(
    285         QuantizedArray(
    286             output_layer.n_bits,
   (...)
    292         for output_layer in output_layers
    293     )
    294     return output_quantizers

File /opt/homebrew/lib/python3.9/site-packages/concrete/ml/quantization/quantized_module.py:281, in <genexpr>(.0)
    274 def _set_output_quantizers(self) -> List[UniformQuantizer]:
    275     """Get the output quantizers.
    276 
    277     Returns:
    278         List[UniformQuantizer]: List of output quantizers.
    279     """
    280     output_layers = list(
--> 281         self.quant_layers_dict[output_name][1]
    282         for output_name in self.ordered_module_output_names
    283     )
    284     output_quantizers = list(
    285         QuantizedArray(
    286             output_layer.n_bits,
   (...)
    292         for output_layer in output_layers
    293     )
    294     return output_quantizers

KeyError: '222'

I am trying to compile the following model:

import torch
import torch.nn as nn
import torch.nn.functional as F
import brevitas.nn as qnn

class TransformerNet(torch.nn.Module):
    def __init__(self):
        super(TransformerNet, self).__init__()
        self.model = nn.Sequential(
            ConvBlock(3, 8, kernel_size=3, stride=1),
            ConvBlock(8, 16, kernel_size=3, stride=2),
            ConvBlock(16, 32, kernel_size=3, stride=2),
            ResidualBlock(32),
            ResidualBlock(32),
            ResidualBlock(32),
            ResidualBlock(32),
            ResidualBlock(32),
            ConvBlock(32, 16, kernel_size=3),
            ConvBlock(16, 8, kernel_size=3),
            ConvBlock(8, 3, kernel_size=3, stride=1, normalize=False, relu=False),
        )

    def forward(self, x):
        return self.model(x)

""" Components of Transformer Net """
class ResidualBlock(torch.nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            ConvBlock(channels, channels, kernel_size=3, stride=1, normalize=True, relu=True),
            ConvBlock(channels, channels, kernel_size=3, stride=1, normalize=True, relu=False),
        )

    def forward(self, x):
        return self.block(x) + x

class ConvBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, normalize=True, relu=True):
        super(ConvBlock, self).__init__()
        self.kernel_size = kernel_size
        self.block = nn.Sequential(
            qnn.QuantIdentity(bit_width=5, return_quant_tensor=True),
            qnn.QuantConv2d(in_channels, out_channels, kernel_size, stride, padding=self.kernel_size//2, return_quant_tensor=True)
        )
        self.norm = nn.BatchNorm2d(out_channels, affine=True) if normalize else None
        self.relu = relu
        
    def forward(self, x):
        x = self.block(x)
        if self.norm is not None:
            x = self.norm(x)
        if self.relu:
            x = F.relu(x)
        return x

I used the following code to compile it:

transformer = TransformerNet()
transformer.eval()
quantized_model = compile_brevitas_qat_model(
    transformer, 
    img_tensor,  # img_tensor has shape torch.Size([1, 3, 32, 32])
    rounding_threshold_bits={"n_bits": 6, "method": "approximate"}
)
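
(For reproduction purposes, a random tensor with the same shape can stand in for img_tensor, for example:)

import torch

# Stand-in input with the same shape as img_tensor ([1, 3, 32, 32])
img_tensor = torch.randn(1, 3, 32, 32)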

My configuration, on a Mac M1 (macOS Sonoma 14.6), is:

Python version:  3.9.19 (main, Mar 19 2024, 16:08:27) 
[Clang 15.0.0 (clang-1500.3.9.4)]
PyTorch version:  1.13.1
Concrete ML version:  1.6.1
Brevitas version:  0.8.0
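
(For reference, these version numbers can be gathered with a small snippet along these lines; importlib.metadata is used here for the installed package versions.)

import sys
from importlib.metadata import version

import torch

print("Python version: ", sys.version)
print("PyTorch version: ", torch.__version__)
print("Concrete ML version: ", version("concrete-ml"))
print("Brevitas version: ", version("brevitas"))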

Hello,
Could we have complete code to reproduce the error, please? You can open a GitHub repo with all your files. We would also need the code to be as small as possible, to make debugging faster. Thanks
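
In the meantime, it might also help to check whether a stripped-down model already triggers the same KeyError. Below is an untested sketch adapted from the ConvBlock in your post (a single quantized conv, batch norm, then F.relu); TinyNet is just a placeholder name, and the hyperparameters are taken from your code:

import torch
import torch.nn as nn
import torch.nn.functional as F
import brevitas.nn as qnn
from concrete.ml.torch.compile import compile_brevitas_qat_model


class TinyNet(nn.Module):
    """Single ConvBlock-style module, stripped down from the model in the post."""

    def __init__(self):
        super().__init__()
        self.quant = qnn.QuantIdentity(bit_width=5, return_quant_tensor=True)
        self.conv = qnn.QuantConv2d(3, 8, kernel_size=3, stride=1, padding=1, return_quant_tensor=True)
        self.norm = nn.BatchNorm2d(8, affine=True)

    def forward(self, x):
        x = self.conv(self.quant(x))
        x = self.norm(x)
        return F.relu(x)


model = TinyNet()
model.eval()

# Same compile call as in the post, with a random input of the reported shape
quantized_module = compile_brevitas_qat_model(
    model,
    torch.randn(1, 3, 32, 32),
    rounding_threshold_bits={"n_bits": 6, "method": "approximate"},
)
print("Compiled without error")

If that small script compiles fine, adding the blocks back one by one (ResidualBlock, the non-normalized final ConvBlock, etc.) should narrow down which part of TransformerNet causes the KeyError.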