I am working on a project that involves text-based models, and was wondering how to use array slicing in Concrete ML (e.g. x[:,i,:]). The goal is to take each one-hot vector at the second index i, feed them through an embedding table and concatenate the output.
It trains with the slicing but fails to compile in brevitas.
Please let me know if it is possible to split and concatenate layer inputs/output in Concrete ML.
class TinyMLP(nn.Module):
def __init__(self, n_neurons, n_bits) -> None:
super().__init__()
self.q1 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
self.fc1 = qnn.QuantLinear(27, 10, bias=False, weight_bit_width=n_bits)
self.q2 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
self.fc2 = qnn.QuantLinear(30, n_neurons, bias=True, weight_bit_width=n_bits)
self.q4 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
self.fc4 = qnn.QuantLinear(n_neurons, 27, bias=True, weight_bit_width=n_bits)
def forward(self, x):
"""Run inference on the tiny CNN, apply the decision layer on the reshaped conv output."""
x_pre = self.q1(x[:, 0, :])
x_pre = self.fc1(x_pre)
for i in range(1,self.n_blocks):
x_pre = torch.cat((x_pre, self.fc1(self.q1(x[:,i,:]))), axis = 1)
x = self.q2(x_pre)
x = self.fc2(x)
x = torch.relu(x)
x = self.q4(x)
x = self.fc4(x)
return x
My Version
concrete-compiler 0.23.4
concrete-ml 1.0.0
concrete-ml-extensions-brevitas 0.1.0
concrete-numpy 0.9.0
concrete-python 1.0.0
Error Message
Traceback (most recent call last):
File "~Concrete/names/array_slicing.py", line 282, in <module>
q_module = compile_brevitas_qat_model(nets[idx], x_train.float(), n_bits=idx+3)
File "~.conda/envs/concrete/lib/python3.9/site-packages/concrete/ml/torch/compile.py", line 416, in compile_brevitas_qat_model
q_module = compile_onnx_model(
File "~.conda/envs/concrete/lib/python3.9/site-packages/concrete/ml/torch/compile.py", line 263, in compile_onnx_model
return _compile_torch_or_onnx_model(
File "~.conda/envs/concrete/lib/python3.9/site-packages/concrete/ml/torch/compile.py", line 125, in _compile_torch_or_onnx_model
quantized_module.compile(
File "~.conda/envs/concrete/lib/python3.9/site-packages/concrete/ml/quantization/quantized_module.py", line 576, in compile
self.fhe_circuit = compiler.compile(
File "~.conda/envs/concrete/lib/python3.9/site-packages/concrete/fhe/compilation/compiler.py", line 434, in compile
self._evaluate("Compiling", inputset)
File "~.conda/envs/concrete/lib/python3.9/site-packages/concrete/fhe/compilation/compiler.py", line 279, in _evaluate
self._trace(first_sample)
File "~.conda/envs/concrete/lib/python3.9/site-packages/concrete/fhe/compilation/compiler.py", line 207, in _trace
self.graph = Tracer.trace(self.function, parameters)
File "~.conda/envs/concrete/lib/python3.9/site-packages/concrete/fhe/tracing/tracer.py", line 77, in trace
output_tracers: Any = function(**arguments)
File "~.conda/envs/concrete/lib/python3.9/site-packages/concrete/ml/common/utils.py", line 1, in _clear_forward_proxy
"""Utils that can be re-used by other pieces of code in the module."""
File "~.conda/envs/concrete/lib/python3.9/site-packages/concrete/ml/quantization/quantized_module.py", line 367, in _clear_forward
output = layer(*inputs)
File "~.conda/envs/concrete/lib/python3.9/site-packages/concrete/ml/quantization/base_quantized_op.py", line 236, in __call__
return self.q_impl(*q_inputs, **self.attrs)
File "~.conda/envs/concrete/lib/python3.9/site-packages/concrete/ml/quantization/quantized_ops.py", line 2155, in q_impl
self.call_impl(*inputs, **attrs),
File "~.conda/envs/concrete/lib/python3.9/site-packages/concrete/ml/quantization/base_quantized_op.py", line 623, in call_impl
outputs = impl_func(*inputs) if not self._has_attr else impl_func(*inputs, **attrs)
File "~.conda/envs/concrete/lib/python3.9/site-packages/concrete/ml/onnx/ops_impl.py", line 1942, in numpy_gather
return (x[tuple(slices)],)
File "~.conda/envs/concrete/lib/python3.9/site-packages/concrete/fhe/tracing/tracer.py", line 750, in __getitem__
raise ValueError(message)
ValueError: Indexing with '0' is not supported