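I tried to create a smaller QAT model for MNIST, but compiling it with `compile_brevitas_qat_model` throws a `KeyError` (the traceback below was captured with `n_bits=7`; the snippet shows `n_bits=6`):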
net=SmallerCNN(n_classes=10, n_bits=6)
model_dev = compile_brevitas_qat_model( net,x_train)
KeyError Traceback (most recent call last)
in <cell line: 2>()
1 net=SmallerCNN(n_classes=10, n_bits=7)
----> 2 model_dev = compile_brevitas_qat_model( net,x_train)
26 frames
/usr/local/lib/python3.10/dist-packages/concrete/ml/torch/compile.py in compile_brevitas_qat_model(torch_model, torch_inputset, n_bits, configuration, artifacts, show_mlir, rounding_threshold_bits, p_error, global_p_error, output_onnx_file, verbose)
413 exporter.onnx_passes.append("eliminate_nop_pad")
414 exporter.onnx_passes.append("fuse_pad_into_conv")
→ 415 onnx_model = exporter.export(
416 torch_model,
417 args=dummy_input_for_tracing,
/usr/local/lib/python3.10/dist-packages/brevitas/export/onnx/manager.py in export(cls, module, args, export_path, input_shape, input_t, disable_warnings, **onnx_export_kwargs)
157 disable_warnings=True,
158 **onnx_export_kwargs):
→ 159 return cls.export_onnx(module, args, export_path, input_shape, input_t, disable_warnings, **onnx_export_kwargs)
/usr/local/lib/python3.10/dist-packages/brevitas/export/onnx/manager.py in export_onnx(cls, module, args, export_path, input_shape, input_t, disable_warnings, **onnx_export_kwargs)
129 model_bytes = BytesIO()
130 export_target = model_bytes
→ 131 torch.onnx.export(module, args, export_target, **onnx_export_kwargs)
132
133 # restore the model to previous properties
/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions)
502 """
503
→ 504 _export(
505 model,
506 args,
/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, onnx_shape_inference, export_modules_as_functions)
1527 _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
1528
→ 1529 graph, params_dict, torch_out = _model_to_graph(
1530 model,
1531 args,
/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, dynamic_axes)
1109
1110 model = _pre_trace_quant_model(model, args)
→ 1111 graph, params, torch_out, module = _create_jit_graph(model, args)
1112 params_dict = _get_named_param_dict(graph, params)
1113
/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py in _create_jit_graph(model, args)
985 return graph, params, torch_out, None
986
→ 987 graph, torch_out = _trace_and_get_graph_from_model(model, args)
988 _C._jit_pass_onnx_lint(graph)
989 state_dict = torch.jit._unique_state_dict(model)
/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py in _trace_and_get_graph_from_model(model, args)
889 prev_autocast_cache_enabled = torch.is_autocast_cache_enabled()
890 torch.set_autocast_cache_enabled(False)
→ 891 trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
892 model,
893 args,
/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py in _get_trace_graph(f, args, kwargs, strict, _force_outplace, return_inputs, _return_inputs_states)
1182 if not isinstance(args, tuple):
1183 args = (args,)
→ 1184 outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
1185 return outs
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
→ 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py in forward(self, *args)
125 return tuple(out_vars)
126
→ 127 graph, out = torch._C._create_graph_by_tracing(
128 wrapper,
129 in_vars + module_state,
/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py in wrapper(*args)
116 if self._return_inputs_states:
117 inputs_states.append(_unflatten(in_args, in_desc))
→ 118 outs.append(self.inner(*trace_inputs))
119 if self._return_inputs_states:
120 inputs_states[0] = (inputs_states[0], trace_inputs)
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
→ 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _slow_forward(self, *input, **kwargs)
1180 recording_scopes = False
1181 try:
→ 1182 result = self.forward(*input, **kwargs)
1183 finally:
1184 if recording_scopes:
in forward(self, x)
57 x = torch.relu(x)
58 x = self.q2(x)
—> 59 x = self.conv2(x)
60 x = torch.relu(x)
61
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1210 input = bw_hook.setup_input_hook(input)
1211
→ 1212 result = forward_call(*input, **kwargs)
1213 if _global_forward_hooks or self._forward_hooks:
1214 for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _slow_forward(self, *input, **kwargs)
1180 recording_scopes = False
1181 try:
→ 1182 result = self.forward(*input, **kwargs)
1183 finally:
1184 if recording_scopes:
/usr/local/lib/python3.10/dist-packages/brevitas/nn/quant_conv.py in forward(self, input)
187
188 def forward(self, input: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
→ 189 return self.forward_impl(input)
190
191 def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]):
/usr/local/lib/python3.10/dist-packages/brevitas/nn/quant_layer.py in forward_impl(self, inp)
315
316 quant_input = self.input_quant(inp)
→ 317 quant_weight = self.quant_weight()
318
319 if quant_input.bit_width is not None:
/usr/local/lib/python3.10/dist-packages/brevitas/nn/mixin/parameter.py in quant_weight(self)
53
54 def quant_weight(self):
—> 55 return self.weight_quant(self.weight)
56
57 def int_weight(self, float_datatype=False):
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
→ 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _slow_forward(self, *input, **kwargs)
1180 recording_scopes = False
1181 try:
→ 1182 result = self.forward(*input, **kwargs)
1183 finally:
1184 if recording_scopes:
/usr/local/lib/python3.10/dist-packages/brevitas/proxy/parameter_quant.py in forward(self, x)
84 if self.is_quant_enabled:
85 impl = self.export_handler if self.export_mode else self.tensor_quant
—> 86 out, scale, zero_point, bit_width = impl(x)
87 return QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
88 else: # quantization disabled
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
→ 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _slow_forward(self, *input, **kwargs)
1180 recording_scopes = False
1181 try:
→ 1182 result = self.forward(*input, **kwargs)
1183 finally:
1184 if recording_scopes:
/usr/local/lib/python3.10/dist-packages/brevitas/export/onnx/handler.py in forward(self, inp, *args, **kwargs)
112 if self.export_debug_name is not None and self.debug_input:
113 inp = debug_fn(inp, '.input')
→ 114 out = self.symbolic_execution(inp, *args, **kwargs)
115 if self.export_debug_name is not None and self.debug_output:
116 if isinstance(out, Tensor):
/usr/local/lib/python3.10/dist-packages/brevitas/export/onnx/qonnx/handler.py in symbolic_execution(self, x)
67
68 def symbolic_execution(self, x: Tensor):
—> 69 quant_weight = self.quant_weights[x.data_ptr()]
70 return super().symbolic_execution(quant_weight)
71
KeyError: 97488528374272