Memory issues cause the kernel to die when compiling

Hi, I have a question about how memory is handled when compiling QAT models.
I tried compiling a simple convolution QAT model with a fairly large calibration input (a 1 x 1 x 15000 x 15000 tensor), and my kernel always dies.
Is there something I am doing wrong?
How much RAM would I need to compile this model?

Here is some minimal code to reproduce the problem:

import brevitas.nn as qnn
import torch
import torch.nn as nn
from concrete.ml.torch.compile import compile_brevitas_qat_model

# Replacement weight: this overrides the declared 1x1 kernel with a 1x1000 one
weight = torch.randn(1, 1, 1, 1000)


class SimpleConvBrevitas(nn.Module):
    """A quantized identity followed by a single quantized convolution."""

    def __init__(self, bit_width):
        super().__init__()
        self.id1 = qnn.QuantIdentity(bit_width=bit_width)
        self.conv = qnn.QuantConv2d(1, 1, 1, bit_width=bit_width, bias=False)
        self.conv.weight = nn.Parameter(weight)

    def forward(self, x):
        """Forward pass of the model."""
        x = self.id1(x)
        x = self.conv(x)
        return x


# Calibration input: a single 1 x 1 x 15000 x 15000 float32 tensor (~0.9 GB)
tensor_ = torch.randn(1, 1, 15000, 15000)
model = SimpleConvBrevitas(bit_width=8)

# Run one forward pass, then compile using the same tensor as the input set
model(tensor_)
compiled_module = compile_brevitas_qat_model(model, tensor_, verbose=True, n_bits=8)

Thanks for any help you can provide on this issue!

Hi!

You should try decreasing the number of examples you give to the compile method. A subset should be enough in most cases, since compilation only needs the min / max of every intermediate value per layer of your model.
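
Here is a hedged sketch of that idea for your repro, reusing model, tensor_ and the import from your snippet. Instead of the full image, it calibrates on a few small random crops, which cover the same value range at a fraction of the memory (the float32 input alone is 15000 * 15000 * 4 bytes, about 0.9 GB). The crop shape and count are arbitrary choices, not a Concrete ML recommendation:

# Calibrate on random crops instead of the full tensor. The crop width is
# kept >= 1000 because the replaced weight gives the conv a 1x1000 kernel.
n_crops, h, w = 8, 100, 1100
crops = []
for _ in range(n_crops):
    y = int(torch.randint(0, 15000 - h, (1,)))
    x = int(torch.randint(0, 15000 - w, (1,)))
    crops.append(tensor_[:, :, y : y + h, x : x + w])
inputset = torch.cat(crops)  # shape (8, 1, 100, 1100): 8 small examples
compiled_module = compile_brevitas_qat_model(model, inputset, n_bits=8)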

If you are not sure whether the input set you gave was representative enough, you can compare the accuracy of your model on the input set vs another batch of examples that were not in it. If there is a noticeable difference, the input set could be at fault.
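
As a rough sketch of that check, reusing compiled_module and inputset from the snippet above: quantization_error is a hypothetical helper, and the fhe="simulate" argument of the compiled module's forward exists in recent Concrete ML versions (adjust to yours). It compares the quantized / float output gap on the calibration examples vs held-out ones:

import numpy as np

# Hypothetical helper: mean absolute gap between the compiled module
# (run in simulation) and the original float model on a batch of examples.
def quantization_error(quantized_module, float_model, batch):
    q_out = quantized_module.forward(batch.numpy(), fhe="simulate")
    f_out = float_model(batch).detach().numpy()
    return float(np.abs(q_out - f_out).mean())

# Held-out examples, drawn the same way as the calibration crops above
fresh = torch.randn(8, 1, 100, 1100)

err_calib = quantization_error(compiled_module, model, inputset)
err_fresh = quantization_error(compiled_module, model, fresh)
# If err_fresh is much larger than err_calib, the input set likely did not
# cover the range of values the model produces, and a larger one is needed.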

Okay, I will do that, thank you for your answer! I was unsure whether there was an option in Concrete to limit the memory load during compilation.