compile_brevitas_qat_model problem

Hi everyone,
I get the following error:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[107], line 3
      1 from concrete.ml.torch.compile import compile_brevitas_qat_model
----> 3 qmodel = compile_brevitas_qat_model(
      4             # Quantized Pytorch model using Brevitas
      5             torch_model=model,
      6             # Representative data-set for compilation and calibration
      7             torch_inputset=input_img,
      8             n_bits=10,
      9             verbose_compilation=True  
     10         )
     12 # Check the maximum bit-width of your model
     13 qmodel.fhe_circuit.graph.maximum_integer_bit_width()

File /usr/local/lib/python3.9/site-packages/concrete/ml/torch/compile.py:349, in compile_brevitas_qat_model(torch_model, torch_inputset, n_bits, configuration, compilation_artifacts, show_mlir, use_virtual_lib, p_error, global_p_error, output_onnx_file, verbose_compilation)
    346 onnx_model = remove_initializer_from_input(onnx_model)
    348 # Compile using the ONNX conversion flow, in QAT mode
--> 349 q_module_vl = compile_onnx_model(
    350     onnx_model,
    351     torch_inputset,
    352     n_bits=n_bits,
    353     import_qat=True,
    354     compilation_artifacts=compilation_artifacts,
    355     show_mlir=show_mlir,
...
     22 """
     24 if not condition:
---> 25     raise error_type(on_error_msg)

AssertionError: Values must be float if value_is_float is set to True, got int64: [ 1 -1]

This is the code:

import torch
from torch.autograd import Variable
from torchvision import transforms
from mnist_cnn import MNISTCNN
use_cuda = True

model = MNISTCNN()
#model.load_state_dict(torch.load("mnist-cnn")['model_state'])
# Define what device we are using
print("CUDA Available: ",torch.cuda.is_available())
device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")

# Initialize the network
model = model.to(device)

# Load the pretrained model
model.load_state_dict(torch.load("mnist-cnn", map_location='cpu')['model_state'])
model
Output:
MNISTCNN(
  (features): Sequential(
    (conv1): Conv2d(1, 64, kernel_size=(5, 5), stride=(2, 2))
    (relu1): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))
    (relu2): ReLU(inplace=True)
    (conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (relu3): ReLU(inplace=True)
    (flatten): Flatten()
    (fc1): Linear(in_features=576, out_features=32, bias=True)
    (relu4): ReLU(inplace=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (flatten2): Flatten()
  )
  (classifier): Linear(in_features=32, out_features=10, bias=True)
)
# `t` and `data` come from an earlier notebook cell that is not shown here
mean, std, var = torch.mean(t), torch.std(t), torch.var(t)
print("Mean, Std and Var before Normalize:\n", mean, std, var)

transformer = transforms.Compose([
    transforms.Resize((28, 28)),
    # transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
image_tensor = transformer(data)
image_tensor = image_tensor.unsqueeze_(1)
input_imgs = Variable(image_tensor)
input_imgs.size() #torch.Size([50, 1, 28, 28])
from concrete.ml.torch.compile import compile_brevitas_qat_model

qmodel = compile_brevitas_qat_model(
            # Quantized Pytorch model using Brevitas
            torch_model=model,
            # Representative data-set for compilation and calibration
            torch_inputset=input_imgs,
            n_bits=10,
            verbose_compilation=True  
        )

I don’t understand what the problem is.
Thank you so much!

I can also share the notebook in case you need it.

The model you are using does not seem to be a QAT model (that is, one built with Brevitas quantized layers): the printout shows only plain Conv2d and Linear layers. In this case I would suggest using compile_torch_model instead. However, with n_bits=10 you will encounter compilation errors because the "accumulator bit width" will be too high, so you will need to lower n_bits substantially to use compile_torch_model.
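To make both points concrete, here is a small sketch that is not part of Concrete ML. `uses_brevitas_layers` is a hypothetical helper that checks, as a heuristic, whether any submodule class is defined in the `brevitas` package. `worst_case_accumulator_bits` is a hypothetical helper giving a rough worst-case bound on the integer width of one dot-product accumulator, assuming the same bit width for weights and inputs, signed weights, and non-negative (post-ReLU) inputs. Concrete ML's own quantization and bit-width analysis work differently, so treat the numbers as orders of magnitude only. The fan-in values are read off the MNISTCNN printout above.

```python
def uses_brevitas_layers(model) -> bool:
    """Heuristic: True if any submodule's class is defined in a `brevitas` package.

    Works on any object exposing `.modules()`, such as a torch.nn.Module.
    """
    return any(type(m).__module__.startswith("brevitas") for m in model.modules())


def worst_case_accumulator_bits(fan_in: int, n_bits_weights: int, n_bits_inputs: int) -> int:
    """Rough upper bound on the signed integer width of one dot-product accumulator.

    Assumes signed weights in [-2**(n_w - 1), 2**(n_w - 1) - 1] and
    non-negative inputs in [0, 2**n_a - 1] (e.g. after a ReLU).
    """
    max_abs_weight = 2 ** (n_bits_weights - 1)
    max_input = 2 ** n_bits_inputs - 1
    max_abs_sum = fan_in * max_abs_weight * max_input
    # Bits needed for the magnitude, plus one sign bit.
    return max_abs_sum.bit_length() + 1


# Fan-in (number of products summed per output) of each layer in the printed MNISTCNN.
FAN_IN = {
    "conv1": 1 * 5 * 5,
    "conv2": 64 * 3 * 3,
    "conv3": 64 * 3 * 3,
    "fc1": 576,
    "classifier": 32,
}

if __name__ == "__main__":
    for n_bits in (10, 3):
        widest = max(worst_case_accumulator_bits(f, n_bits, n_bits) for f in FAN_IN.values())
        print(f"n_bits={n_bits}: worst-case accumulator ~{widest} bits")
```

For the MNISTCNN above, `uses_brevitas_layers(model)` should return False, and the script prints about 30 bits for n_bits=10 versus 15 bits for n_bits=3: each extra bit on both weights and inputs adds roughly two bits to the accumulator. The width of the actually compiled circuit can be checked with `qmodel.fhe_circuit.graph.maximum_integer_bit_width()`, as in the original cell.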

If you still get the same error, please post the full stack trace (expand the Jupyter notebook cell containing the trace before copying it) or share the full notebook.