Hi. When I try to compile an ONNX model with Concrete ML, I get a bare AssertionError with no explanatory message. Here is the complete stack trace:
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
Cell In[13], line 4
1 onnx_model = onnx.load("model.onnx")
2 onnx.checker.check_model(onnx_model)
----> 4 private_model = compile_onnx_model(
5 onnx_model,
6 img_tensor,
7 n_bits=5,
8 rounding_threshold_bits={"n_bits": 6, "method": "approximate"}
9 )
File /opt/homebrew/lib/python3.9/site-packages/concrete/ml/torch/compile.py:409, in compile_onnx_model(onnx_model, torch_inputset, import_qat, configuration, artifacts, show_mlir, n_bits, rounding_threshold_bits, p_error, global_p_error, verbose, inputs_encryption_status, reduce_sum_copy)
402 onnx_model_opset_version = get_onnx_opset_version(onnx_model)
403 assert_true(
404 onnx_model_opset_version == OPSET_VERSION_FOR_ONNX_EXPORT,
405 f"ONNX version must be {OPSET_VERSION_FOR_ONNX_EXPORT} "
406 f"but it is {onnx_model_opset_version}",
407 )
--> 409 return _compile_torch_or_onnx_model(
410 onnx_model,
411 torch_inputset,
412 import_qat,
413 configuration=configuration,
414 artifacts=artifacts,
415 show_mlir=show_mlir,
416 n_bits=n_bits,
417 rounding_threshold_bits=rounding_threshold_bits,
418 p_error=p_error,
419 global_p_error=global_p_error,
420 verbose=verbose,
421 inputs_encryption_status=inputs_encryption_status,
422 reduce_sum_copy=reduce_sum_copy,
423 )
File /opt/homebrew/lib/python3.9/site-packages/concrete/ml/torch/compile.py:214, in _compile_torch_or_onnx_model(model, torch_inputset, import_qat, configuration, artifacts, show_mlir, n_bits, rounding_threshold_bits, p_error, global_p_error, verbose, inputs_encryption_status, reduce_sum_copy, composition_mapping)
208 raise ValueError(
209 "Composition must be enabled in 'configuration' in order to trigger a re-quantization "
210 "step on the circuit's outputs."
211 )
213 # Build the quantized module
--> 214 quantized_module = build_quantized_module(
215 model=model,
216 torch_inputset=inputset_as_numpy_tuple,
217 import_qat=import_qat,
218 n_bits=n_bits,
219 rounding_threshold_bits=rounding_threshold_bits,
220 reduce_sum_copy=reduce_sum_copy,
221 )
223 # Check that p_error or global_p_error is not set in both the configuration and in the direct
224 # parameters
225 check_there_is_no_p_error_options_in_configuration(configuration)
File /opt/homebrew/lib/python3.9/site-packages/concrete/ml/torch/compile.py:127, in build_quantized_module(model, torch_inputset, import_qat, n_bits, rounding_threshold_bits, reduce_sum_copy)
121 post_training_quant = post_training(n_bits, numpy_model, rounding_threshold_bits)
123 # Build the quantized module
124 # FIXME: mismatch here. We traced with dummy_input_for_tracing which made some operator
125 # only work over shape of (1, ., .). For example, some reshape have newshape hardcoded based
126 # on the inputset we sent in the NumpyModule.
--> 127 quantized_module = post_training_quant.quantize_module(*inputset_as_numpy_tuple)
129 # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4127
130 if reduce_sum_copy:
File /opt/homebrew/lib/python3.9/site-packages/concrete/ml/quantization/post_training.py:699, in ONNXConverter.quantize_module(self, *calibration_data)
696 # First transform all parameters to their quantized version
697 self._quantize_params()
--> 699 self._quantize_layers(*calibration_data)
701 # Create quantized module from self.quant_layers_dict
702 quantized_module = QuantizedModule(
703 ordered_module_input_names=(
704 graph_input.name for graph_input in self.numpy_model.onnx_model.graph.input
(...)
710 onnx_model=self.numpy_model.onnx_model,
711 )
File /opt/homebrew/lib/python3.9/site-packages/concrete/ml/quantization/post_training.py:627, in ONNXConverter._quantize_layers(self, *input_calibration_data)
618 self.quant_ops_dict[output_name] = (
619 tuple(variable_input_names),
620 quantized_op_instance,
621 )
623 layer_quant = list(
624 node_override_quantizer.get(input_name, None)
625 for input_name in variable_input_names
626 )
--> 627 output_calibration_data, layer_quantizer = self._process_layer(
628 quantized_op_instance, *curr_calibration_data, quantizers=layer_quant
629 )
630 node_results[output_name] = output_calibration_data
631 node_override_quantizer[output_name] = layer_quantizer
File /opt/homebrew/lib/python3.9/site-packages/concrete/ml/quantization/post_training.py:873, in PostTrainingAffineQuantization._process_layer(self, quantized_op, quantizers, *calibration_data)
852 def _process_layer(
853 self,
854 quantized_op: QuantizedOp,
855 *calibration_data: numpy.ndarray,
856 quantizers: List[Optional[UniformQuantizer]],
857 ) -> Tuple[numpy.ndarray, Optional[UniformQuantizer]]:
858 """Configure a graph operation by performing calibration for uniform quantization.
859
860 Args:
(...)
870 numpy.ndarray: calibration data for the following operators
871 """
--> 873 return self._calibrate_layers_activation(
874 True, quantized_op, *calibration_data, quantizers=quantizers
875 )
File /opt/homebrew/lib/python3.9/site-packages/concrete/ml/quantization/post_training.py:393, in ONNXConverter._calibrate_layers_activation(self, calibrate_quantized, quantized_op, quantizers, *calibration_data)
390 # For PTQ, the calibration is performed on quantized data. But
391 # raw operation output (RawOpOutput) data should not be quantized
392 if calibrate_quantized and not isinstance(quant_result, RawOpOutput):
--> 393 assert isinstance(quant_result, QuantizedArray)
394 return (
395 quant_result.dequant(),
396 quant_result.quantizer if isinstance(quant_result, QuantizedArray) else None,
397 )
399 # For QAT, the calibration is performed on raw data, performing
400 # calibration on quantized that would confound inferred QAT and PTQ.
AssertionError:
I have checked that my ONNX model only contains operators that are supported in Concrete ML (I have included the snippet I used for that check after the compilation code below). I am kind of lost here. Any help would be appreciated. Here is the code used to compile the model:
import onnx
from concrete.ml.torch.compile import compile_onnx_model

onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)

# img_tensor is my calibration input set (a representative batch of images)
private_model = compile_onnx_model(
    onnx_model,
    img_tensor,
    n_bits=5,
    rounding_threshold_bits={"n_bits": 6, "method": "approximate"},
)
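For reference, here is the quick check I ran to list every operator in the graph and the declared opset. The import of ONNX_OPS_TO_NUMPY_IMPL is my best guess at where Concrete ML keeps its numpy operator implementations internally (it may be internal and version-dependent), so I also compared the printed operator list against the supported-operators page in the docs by hand:

from collections import Counter

import onnx

onnx_model = onnx.load("model.onnx")

# Count every operator type that appears in the graph
op_counts = Counter(node.op_type for node in onnx_model.graph.node)
print("Operators in model:", dict(op_counts))

# Print the declared opset, since compile_onnx_model asserts on it early on
print("Opset imports:", [(opset.domain, opset.version) for opset in onnx_model.opset_import])

# Best-guess internal table of supported ops; may be internal/version-dependent
from concrete.ml.onnx.onnx_utils import ONNX_OPS_TO_NUMPY_IMPL
print("Ops without a numpy impl:", set(op_counts) - set(ONNX_OPS_TO_NUMPY_IMPL))

Both checks came back clean for my model (and the opset assertion in compile_onnx_model clearly passes, since the traceback gets past it), which is why the bare AssertionError has me stuck.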