Hi. I am trying to compile a quantized Brevitas model, but I am hitting a strange error during compilation. Here is the complete stack trace -
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
Cell In[16], line 2
1 # with torch.no_grad():
----> 2 quantized_module = compile_brevitas_qat_model(
3 transformer,
4 img_tensor,
5 rounding_threshold_bits={"n_bits": 6, "method": "approximate"}
6 )
File /opt/homebrew/lib/python3.9/site-packages/concrete/ml/torch/compile.py:560, in compile_brevitas_qat_model(torch_model, torch_inputset, n_bits, configuration, artifacts, show_mlir, rounding_threshold_bits, p_error, global_p_error, output_onnx_file, verbose, inputs_encryption_status, reduce_sum_copy)
554 assert_true(
555 n_bits is None or isinstance(n_bits, (int, dict)),
556 "The n_bits parameter must be either a dictionary, an integer or None",
557 )
559 # Compile using the ONNX conversion flow, in QAT mode
--> 560 q_module = compile_onnx_model(
561 onnx_model,
562 torch_inputset,
563 n_bits=n_bits,
564 import_qat=True,
565 artifacts=artifacts,
566 show_mlir=show_mlir,
567 rounding_threshold_bits=rounding_threshold_bits,
568 configuration=configuration,
569 p_error=p_error,
570 global_p_error=global_p_error,
571 verbose=verbose,
572 inputs_encryption_status=inputs_encryption_status,
573 reduce_sum_copy=reduce_sum_copy,
574 )
576 # Remove the tempfile if we used one
577 if use_tempfile:
File /opt/homebrew/lib/python3.9/site-packages/concrete/ml/torch/compile.py:409, in compile_onnx_model(onnx_model, torch_inputset, import_qat, configuration, artifacts, show_mlir, n_bits, rounding_threshold_bits, p_error, global_p_error, verbose, inputs_encryption_status, reduce_sum_copy)
402 onnx_model_opset_version = get_onnx_opset_version(onnx_model)
403 assert_true(
404 onnx_model_opset_version == OPSET_VERSION_FOR_ONNX_EXPORT,
405 f"ONNX version must be {OPSET_VERSION_FOR_ONNX_EXPORT} "
406 f"but it is {onnx_model_opset_version}",
407 )
--> 409 return _compile_torch_or_onnx_model(
410 onnx_model,
411 torch_inputset,
412 import_qat,
413 configuration=configuration,
414 artifacts=artifacts,
415 show_mlir=show_mlir,
416 n_bits=n_bits,
417 rounding_threshold_bits=rounding_threshold_bits,
418 p_error=p_error,
419 global_p_error=global_p_error,
420 verbose=verbose,
421 inputs_encryption_status=inputs_encryption_status,
422 reduce_sum_copy=reduce_sum_copy,
423 )
File /opt/homebrew/lib/python3.9/site-packages/concrete/ml/torch/compile.py:214, in _compile_torch_or_onnx_model(model, torch_inputset, import_qat, configuration, artifacts, show_mlir, n_bits, rounding_threshold_bits, p_error, global_p_error, verbose, inputs_encryption_status, reduce_sum_copy, composition_mapping)
208 raise ValueError(
209 "Composition must be enabled in 'configuration' in order to trigger a re-quantization "
210 "step on the circuit's outputs."
211 )
213 # Build the quantized module
--> 214 quantized_module = build_quantized_module(
215 model=model,
216 torch_inputset=inputset_as_numpy_tuple,
217 import_qat=import_qat,
218 n_bits=n_bits,
219 rounding_threshold_bits=rounding_threshold_bits,
220 reduce_sum_copy=reduce_sum_copy,
221 )
223 # Check that p_error or global_p_error is not set in both the configuration and in the direct
224 # parameters
225 check_there_is_no_p_error_options_in_configuration(configuration)
File /opt/homebrew/lib/python3.9/site-packages/concrete/ml/torch/compile.py:127, in build_quantized_module(model, torch_inputset, import_qat, n_bits, rounding_threshold_bits, reduce_sum_copy)
121 post_training_quant = post_training(n_bits, numpy_model, rounding_threshold_bits)
123 # Build the quantized module
124 # FIXME: mismatch here. We traced with dummy_input_for_tracing which made some operator
125 # only work over shape of (1, ., .). For example, some reshape have newshape hardcoded based
126 # on the inputset we sent in the NumpyModule.
--> 127 quantized_module = post_training_quant.quantize_module(*inputset_as_numpy_tuple)
129 # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4127
130 if reduce_sum_copy:
File /opt/homebrew/lib/python3.9/site-packages/concrete/ml/quantization/post_training.py:702, in ONNXConverter.quantize_module(self, *calibration_data)
699 self._quantize_layers(*calibration_data)
701 # Create quantized module from self.quant_layers_dict
--> 702 quantized_module = QuantizedModule(
703 ordered_module_input_names=(
704 graph_input.name for graph_input in self.numpy_model.onnx_model.graph.input
705 ),
706 ordered_module_output_names=(
707 graph_output.name for graph_output in self.numpy_model.onnx_model.graph.output
708 ),
709 quant_layers_dict=self.quant_ops_dict,
710 onnx_model=self.numpy_model.onnx_model,
711 )
713 adapter = PowerOfTwoScalingRoundPBSAdapter(quantized_module)
714 # Apply the round PBS optimization if possible
File /opt/homebrew/lib/python3.9/site-packages/concrete/ml/quantization/quantized_module.py:138, in QuantizedModule.__init__(self, ordered_module_input_names, ordered_module_output_names, quant_layers_dict, onnx_model)
136 # Initialize output quantizers based on quant_layers_dict
137 if self.quant_layers_dict:
--> 138 self.output_quantizers = self._set_output_quantizers()
139 else:
140 self.output_quantizers = []
File /opt/homebrew/lib/python3.9/site-packages/concrete/ml/quantization/quantized_module.py:280, in QuantizedModule._set_output_quantizers(self)
274 def _set_output_quantizers(self) -> List[UniformQuantizer]:
275 """Get the output quantizers.
276
277 Returns:
278 List[UniformQuantizer]: List of output quantizers.
279 """
--> 280 output_layers = list(
281 self.quant_layers_dict[output_name][1]
282 for output_name in self.ordered_module_output_names
283 )
284 output_quantizers = list(
285 QuantizedArray(
286 output_layer.n_bits,
(...)
292 for output_layer in output_layers
293 )
294 return output_quantizers
File /opt/homebrew/lib/python3.9/site-packages/concrete/ml/quantization/quantized_module.py:281, in <genexpr>(.0)
274 def _set_output_quantizers(self) -> List[UniformQuantizer]:
275 """Get the output quantizers.
276
277 Returns:
278 List[UniformQuantizer]: List of output quantizers.
279 """
280 output_layers = list(
--> 281 self.quant_layers_dict[output_name][1]
282 for output_name in self.ordered_module_output_names
283 )
284 output_quantizers = list(
285 QuantizedArray(
286 output_layer.n_bits,
(...)
292 for output_layer in output_layers
293 )
294 return output_quantizers
KeyError: '222'
I am trying to compile the following model -
import torch
import torch.nn as nn
import torch.nn.functional as F
import brevitas.nn as qnn

class TransformerNet(torch.nn.Module):
    def __init__(self):
        super(TransformerNet, self).__init__()
        self.model = nn.Sequential(
            ConvBlock(3, 8, kernel_size=3, stride=1),
            ConvBlock(8, 16, kernel_size=3, stride=2),
            ConvBlock(16, 32, kernel_size=3, stride=2),
            ResidualBlock(32),
            ResidualBlock(32),
            ResidualBlock(32),
            ResidualBlock(32),
            ResidualBlock(32),
            ConvBlock(32, 16, kernel_size=3),
            ConvBlock(16, 8, kernel_size=3),
            ConvBlock(8, 3, kernel_size=3, stride=1, normalize=False, relu=False),
        )

    def forward(self, x):
        return self.model(x)


""" Components of TransformerNet """

class ResidualBlock(torch.nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            ConvBlock(channels, channels, kernel_size=3, stride=1, normalize=True, relu=True),
            ConvBlock(channels, channels, kernel_size=3, stride=1, normalize=True, relu=False),
        )

    def forward(self, x):
        return self.block(x) + x


class ConvBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, normalize=True, relu=True):
        super(ConvBlock, self).__init__()
        self.kernel_size = kernel_size
        self.block = nn.Sequential(
            qnn.QuantIdentity(bit_width=5, return_quant_tensor=True),
            qnn.QuantConv2d(in_channels, out_channels, kernel_size, stride,
                            padding=self.kernel_size // 2, return_quant_tensor=True),
        )
        self.norm = nn.BatchNorm2d(out_channels, affine=True) if normalize else None
        self.relu = relu

    def forward(self, x):
        x = self.block(x)
        if self.norm is not None:
            x = self.norm(x)
        if self.relu:
            x = F.relu(x)
        return x
I used the following code to compile the model -
from concrete.ml.torch.compile import compile_brevitas_qat_model

transformer = TransformerNet()
transformer.eval()

quantized_model = compile_brevitas_qat_model(
    transformer,
    img_tensor,  # img_tensor shape: torch.Size([1, 3, 32, 32])
    rounding_threshold_bits={"n_bits": 6, "method": "approximate"},
)
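For context, img_tensor is a single preprocessed 32x32 RGB image. If it helps to reproduce the issue, a random tensor of the same shape should go through the same compilation path (this stand-in is not my actual data) -

import torch

# Stand-in for my calibration input: same shape (1, 3, 32, 32) as the real
# img_tensor, which is a normalized image tensor rather than random values.
img_tensor = torch.randn(1, 3, 32, 32)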
I have the following configuration on a Mac M1 (Sonoma 14.6) -
Python version: 3.9.19 (main, Mar 19 2024, 16:08:27)
[Clang 15.0.0 (clang-1500.3.9.4)]
PyTorch version: 1.13.1
Concrete ML version: 1.6.1
Brevitas version: 0.8.0