Thank you @luis for your response. I have moved the data permutation outside of the model and removed the pruning from my CNN.
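For reference, the permutation is now applied once to the dataset before compilation, roughly like this (x_train is the NHWC array of 128x128 RGB images; the variable names are just the ones I use below):

    import numpy as np

    # Reorder NHWC (batch, height, width, channels) into the NCHW layout
    # expected by the convolution layers, outside of the model's forward pass.
    # Note that np.transpose returns a strided view of the original array,
    # not a contiguous copy.
    x_train_modified = np.transpose(x_train, (0, 3, 1, 2))

Here is my updated CNN code: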
import torch.nn as nn
import torch.nn.functional as F
import brevitas.nn as qnn

class FaceRecognitionModel_QAT(nn.Module):
    # n_bits: number of bits used for both activation and weight quantization.
    def __init__(self, num_classes, n_bits):
        super().__init__()
        a_bits = n_bits
        w_bits = n_bits
        # QuantIdentity at the entry point of the network quantizes the input.
        # This CNN computes 18179856 MACs.
        self.q1 = qnn.QuantIdentity(bit_width=a_bits, return_quant_tensor=True)
        self.conv1 = qnn.QuantConv2d(3, 32, 3, stride=1, padding=1, weight_bit_width=w_bits)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = qnn.QuantConv2d(32, 64, 3, stride=1, padding=1, weight_bit_width=w_bits)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = qnn.QuantConv2d(64, 128, 3, stride=1, padding=1, weight_bit_width=w_bits)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        # A 128x128 input is reduced to 128 channels of 16x16 by the three 2x2 poolings,
        # hence 128 * 16 * 16 input features.
        self.fc1 = qnn.QuantLinear(128 * 16 * 16, 512, bias=False, weight_bit_width=w_bits)
        self.fc2 = qnn.QuantLinear(512, num_classes, bias=False, weight_bit_width=w_bits)
        # Pruning is no longer enabled; the old toggle_pruning code is kept below for reference.
        # self.toggle_pruning(True)

    # def toggle_pruning(self, enable):
    #     # Maximum number of active neurons (i.e., corresponding weight != 0)
    #     n_active = 12
    #     # Go through all the convolution layers
    #     for layer in (self.conv1, self.conv2, self.conv3):
    #         s = layer.weight.shape
    #         # st = [fan-out, fan-in]: fan-out is the number of neurons in the
    #         # layer, fan-in is kernel width x kernel height x in_channels.
    #         st = [s[0], np.prod(s[1:])]
    #         if st[1] > n_active:
    #             if enable:
    #                 # Registers a forward hook that multiplies the weights with
    #                 # a 0/1 mask during forward.
    #                 prune.l1_unstructured(layer, "weight", (st[1] - n_active) * st[0])
    #             else:
    #                 # Disabling pruning bakes the mask into the weights and
    #                 # stores the result in the weight member.
    #                 prune.remove(layer, "weight")

    def forward(self, x):
        x = self.q1(x)
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.pool3(F.relu(self.conv3(x)))
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
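The models are then compiled one per bit width, essentially as in the notebook cell shown in the traceback (the bit_range values and the omitted training step are placeholders; num_classes=158 matches the fc2 weight shape in the trace):

    from concrete.ml.torch.compile import compile_brevitas_qat_model

    bit_range = [2, 3, 4]  # placeholder values; the failing run uses 2 bits
    nets = [FaceRecognitionModel_QAT(num_classes=158, n_bits=n) for n in bit_range]
    # ... each net is trained here ...

    accum_bits = []
    sim_time = []
    for idx in range(len(bit_range)):
        q_module = compile_brevitas_qat_model(nets[idx], x_train_modified)
        accum_bits.append(q_module.fhe_circuit.graph.maximum_integer_bit_width())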
The complete error message is:
[W shape_type_inference.cpp:1920] Warning: The shape inference of onnx.brevitas::Quant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)
[W shape_type_inference.cpp:1920] Warning: The shape inference of onnx.brevitas::Quant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)
================ Diagnostic Run torch.onnx.export version 2.0.1 ================
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 2 WARNING 0 ERROR ========================
2 WARNING were not printed due to the log level.
---------------------------------------------------------------------------
SymbolicValueError Traceback (most recent call last)
Cell In[16], line 5
3 sim_time = []
4 for idx in range(len(bit_range)):
----> 5 q_module = compile_brevitas_qat_model(nets[idx], x_train_modified)
7 accum_bits.append(q_module.fhe_circuit.graph.maximum_integer_bit_width())
9 start_time = time.time()
File ~/anaconda3/envs/myenv/lib/python3.10/site-packages/concrete/ml/torch/compile.py:415, in compile_brevitas_qat_model(torch_model, torch_inputset, n_bits, configuration, artifacts, show_mlir, rounding_threshold_bits, p_error, global_p_error, output_onnx_file, verbose)
413 exporter.onnx_passes.append("eliminate_nop_pad")
414 exporter.onnx_passes.append("fuse_pad_into_conv")
--> 415 onnx_model = exporter.export(
416 torch_model,
417 args=dummy_input_for_tracing,
418 export_path=str(output_onnx_file_path),
419 keep_initializers_as_inputs=False,
420 opset_version=OPSET_VERSION_FOR_ONNX_EXPORT,
421 )
422 onnx_model = remove_initializer_from_input(onnx_model)
424 if n_bits is None:
File ~/anaconda3/envs/myenv/lib/python3.10/site-packages/brevitas/export/onnx/manager.py:159, in ONNXBaseManager.export(cls, module, args, export_path, input_shape, input_t, disable_warnings, **onnx_export_kwargs)
149 @classmethod
150 def export(
151 cls,
(...)
157 disable_warnings=True,
158 **onnx_export_kwargs):
--> 159 return cls.export_onnx(module, args, export_path, input_shape, input_t, disable_warnings, **onnx_export_kwargs)
File ~/anaconda3/envs/myenv/lib/python3.10/site-packages/brevitas/export/onnx/manager.py:131, in ONNXBaseManager.export_onnx(cls, module, args, export_path, input_shape, input_t, disable_warnings, **onnx_export_kwargs)
129 model_bytes = BytesIO()
130 export_target = model_bytes
--> 131 torch.onnx.export(module, args, export_target, **onnx_export_kwargs)
133 # restore the model to previous properties
134 module.apply(lambda m: _restore_inp_caching_mode(m))
File ~/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/onnx/utils.py:506, in export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions)
188 @_beartype.beartype
189 def export(
190 model: Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction],
(...)
206 export_modules_as_functions: Union[bool, Collection[Type[torch.nn.Module]]] = False,
207 ) -> None:
208 r"""Exports a model into ONNX format.
209
210 If ``model`` is not a :class:`torch.jit.ScriptModule` nor a
(...)
503 All errors are subclasses of :class:`errors.OnnxExporterError`.
504 """
--> 506 _export(
507 model,
508 args,
509 f,
510 export_params,
511 verbose,
512 training,
513 input_names,
514 output_names,
515 operator_export_type=operator_export_type,
516 opset_version=opset_version,
517 do_constant_folding=do_constant_folding,
518 dynamic_axes=dynamic_axes,
519 keep_initializers_as_inputs=keep_initializers_as_inputs,
520 custom_opsets=custom_opsets,
521 export_modules_as_functions=export_modules_as_functions,
522 )
File ~/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/onnx/utils.py:1548, in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, onnx_shape_inference, export_modules_as_functions)
1545 dynamic_axes = {}
1546 _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
-> 1548 graph, params_dict, torch_out = _model_to_graph(
1549 model,
1550 args,
1551 verbose,
1552 input_names,
1553 output_names,
1554 operator_export_type,
1555 val_do_constant_folding,
1556 fixed_batch_size=fixed_batch_size,
1557 training=training,
1558 dynamic_axes=dynamic_axes,
1559 )
1561 # TODO: Don't allocate a in-memory string for the protobuf
1562 defer_weight_export = (
1563 export_type is not _exporter_states.ExportTypes.PROTOBUF_FILE
1564 )
File ~/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/onnx/utils.py:1117, in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, dynamic_axes)
1114 params_dict = _get_named_param_dict(graph, params)
1116 try:
-> 1117 graph = _optimize_graph(
1118 graph,
1119 operator_export_type,
1120 _disable_torch_constant_prop=_disable_torch_constant_prop,
1121 fixed_batch_size=fixed_batch_size,
1122 params_dict=params_dict,
1123 dynamic_axes=dynamic_axes,
1124 input_names=input_names,
1125 module=module,
1126 )
1127 except Exception as e:
1128 torch.onnx.log("Torch IR graph at exception: ", graph)
File ~/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/onnx/utils.py:665, in _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop, fixed_batch_size, params_dict, dynamic_axes, input_names, module)
662 _C._jit_pass_onnx_set_dynamic_input_shape(graph, dynamic_axes, input_names)
663 _C._jit_pass_onnx_lint(graph)
--> 665 graph = _C._jit_pass_onnx(graph, operator_export_type)
666 _C._jit_pass_onnx_lint(graph)
667 _C._jit_pass_lint(graph)
File ~/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/onnx/utils.py:1891, in _run_symbolic_function(graph, block, node, inputs, env, operator_export_type)
1886 if symbolic_fn is not None:
1887 # TODO Wrap almost identical attrs assignment or comment the difference.
1888 attrs = {
1889 k: symbolic_helper._node_get(node, k) for k in node.attributeNames()
1890 }
-> 1891 return symbolic_fn(graph_context, *inputs, **attrs)
1893 attrs = {
1894 k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k)
1895 for k in node.attributeNames()
1896 }
1897 if namespace == "onnx":
1898 # Clone node to trigger ONNX shape inference
File ~/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/onnx/symbolic_helper.py:306, in parse_args.<locals>.decorator.<locals>.wrapper(g, *args, **kwargs)
300 if len(kwargs) == 1:
301 assert "_outputs" in kwargs, (
302 f"Symbolic function {fn.__name__}'s '**kwargs' can only contain "
303 f"'_outputs' key at '**kwargs'. "
304 f"{FILE_BUG_MSG}"
305 )
--> 306 return fn(g, *args, **kwargs)
File ~/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/onnx/symbolic_opset9.py:2451, in _convolution(g, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32)
2448 kernel_shape = None
2450 if kernel_shape is None or any([i is None for i in kernel_shape]):
-> 2451 raise errors.SymbolicValueError(
2452 "Unsupported: ONNX export of convolution for kernel of unknown shape.",
2453 input,
2454 )
2456 args = [input, weight]
2457 # ONNX only supports 1D bias
SymbolicValueError: Unsupported: ONNX export of convolution for kernel of unknown shape. [Caused by the value 'x defined in (%x : Float(*, *, *, *, strides=[49152, 1, 384, 3], requires_grad=0, device=cpu) = onnx.brevitas::Quant[narrow=0, rounding_mode="ROUND", signed=1](%x.193, %scale, %zero_point, %bit_width), scope: __main__.FaceRecognitionModel_QAT::/brevitas.nn.quant_activation.QuantIdentity::q1/brevitas.proxy.runtime_quant.ActQuantProxyFromInjector::act_quant/brevitas.export.onnx.qonnx.handler.BrevitasActQuantProxyHandler::export_handler # /home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/autograd/function.py:506:0
)' (type 'Tensor') in the TorchScript graph. The containing node has kind 'onnx.brevitas::Quant'.]
(node defined in /home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/autograd/function.py(506): apply
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/brevitas/export/onnx/qonnx/handler.py(49): symbolic_execution
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/brevitas/export/onnx/handler.py(114): forward
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/brevitas/proxy/runtime_quant.py(150): forward
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/brevitas/nn/quant_layer.py(148): forward
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/tmp/ipykernel_393/1994898414.py(51): forward
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/jit/_trace.py(118): wrapper
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/jit/_trace.py(127): forward
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/jit/_trace.py(1268): _get_trace_graph
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/onnx/utils.py(893): _trace_and_get_graph_from_model
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/onnx/utils.py(989): _create_jit_graph
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/onnx/utils.py(1113): _model_to_graph
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/onnx/utils.py(1548): _export
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/onnx/utils.py(506): export
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/brevitas/export/onnx/manager.py(131): export_onnx
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/brevitas/export/onnx/manager.py(159): export
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/concrete/ml/torch/compile.py(415): compile_brevitas_qat_model
/tmp/ipykernel_393/382210858.py(5): <module>
/home/khalyl/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3508): run_code
/home/khalyl/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3448): run_ast_nodes
/home/khalyl/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3269): run_cell_async
/home/khalyl/.local/lib/python3.10/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner
/home/khalyl/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3064): _run_cell
/home/khalyl/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3009): run_cell
/home/khalyl/.local/lib/python3.10/site-packages/ipykernel/zmqshell.py(546): run_cell
/home/khalyl/.local/lib/python3.10/site-packages/ipykernel/ipkernel.py(422): do_execute
/home/khalyl/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py(740): execute_request
/home/khalyl/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py(412): dispatch_shell
/home/khalyl/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py(505): process_one
/home/khalyl/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py(516): dispatch_queue
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/asyncio/events.py(80): _run
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/asyncio/base_events.py(1906): _run_once
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/asyncio/base_events.py(603): run_forever
/home/khalyl/.local/lib/python3.10/site-packages/tornado/platform/asyncio.py(195): start
/home/khalyl/.local/lib/python3.10/site-packages/ipykernel/kernelapp.py(736): start
/home/khalyl/.local/lib/python3.10/site-packages/traitlets/config/application.py(1043): launch_instance
/home/khalyl/.local/lib/python3.10/site-packages/ipykernel_launcher.py(17): <module>
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/runpy.py(86): _run_code
/home/khalyl/anaconda3/envs/myenv/lib/python3.10/runpy.py(196): _run_module_as_main
)
Inputs:
#0: x.193 defined in (%x.193 : Float(1, 3, 128, 128, strides=[49152, 1, 384, 3], requires_grad=0, device=cpu), %conv1.weight : Float(32, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=1, device=cpu), %conv1.bias : Float(32, strides=[1], requires_grad=1, device=cpu), %conv2.weight : Float(64, 32, 3, 3, strides=[288, 9, 3, 1], requires_grad=1, device=cpu), %conv2.bias : Float(64, strides=[1], requires_grad=1, device=cpu), %conv3.weight : Float(128, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=1, device=cpu), %conv3.bias : Float(128, strides=[1], requires_grad=1, device=cpu), %fc1.weight : Float(512, 32768, strides=[32768, 1], requires_grad=1, device=cpu), %fc2.weight : Float(158, 512, strides=[512, 1], requires_grad=1, device=cpu) = prim::Param()
) (type 'Tensor')
#1: scale defined in (%scale : Float(requires_grad=0, device=cpu) = onnx::Constant[value={0.5}](), scope: __main__.FaceRecognitionModel_QAT::/brevitas.nn.quant_activation.QuantIdentity::q1/brevitas.proxy.runtime_quant.ActQuantProxyFromInjector::act_quant/brevitas.export.onnx.qonnx.handler.BrevitasActQuantProxyHandler::export_handler # /home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/autograd/function.py:506:0
) (type 'Tensor')
#2: zero_point defined in (%zero_point : Float(requires_grad=0, device=cpu) = onnx::Constant[value={0}](), scope: __main__.FaceRecognitionModel_QAT::/brevitas.nn.quant_activation.QuantIdentity::q1/brevitas.proxy.runtime_quant.ActQuantProxyFromInjector::act_quant/brevitas.export.onnx.qonnx.handler.BrevitasActQuantProxyHandler::export_handler # /home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/autograd/function.py:506:0
) (type 'Tensor')
#3: bit_width defined in (%bit_width : Float(requires_grad=0, device=cpu) = onnx::Constant[value={2}](), scope: __main__.FaceRecognitionModel_QAT::/brevitas.nn.quant_activation.QuantIdentity::q1/brevitas.proxy.runtime_quant.ActQuantProxyFromInjector::act_quant/brevitas.export.onnx.qonnx.handler.BrevitasActQuantProxyHandler::export_handler # /home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/brevitas/export/onnx/qonnx/handler.py:46:0
) (type 'Tensor')
Outputs:
#0: x defined in (%x : Float(*, *, *, *, strides=[49152, 1, 384, 3], requires_grad=0, device=cpu) = onnx.brevitas::Quant[narrow=0, rounding_mode="ROUND", signed=1](%x.193, %scale, %zero_point, %bit_width), scope: __main__.FaceRecognitionModel_QAT::/brevitas.nn.quant_activation.QuantIdentity::q1/brevitas.proxy.runtime_quant.ActQuantProxyFromInjector::act_quant/brevitas.export.onnx.qonnx.handler.BrevitasActQuantProxyHandler::export_handler # /home/khalyl/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/autograd/function.py:506:0
) (type 'Tensor')