Hi, I have a question about memory usage when compiling QAT models.
I am trying to compile a simple QAT convolution model with a fairly large calibration input (a single tensor of shape (1, 1, 15000, 15000)), and my kernel dies every time.
Is there something I am doing wrong?
How much RAM would I need to handle the compilation of this model?
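For scale, here is a rough estimate of what the raw calibration input alone occupies: element count times bytes per element. It assumes float32 for the torch tensor and float64 for the NumPy array (the defaults of `torch.randn` and `np.random.randn`). `tensor_nbytes` is just a hypothetical helper for this arithmetic. The result is only a lower bound, because any copies or intermediate results created during compilation (for example a conversion of the input to float64) come on top of it.

```python
import math


def tensor_nbytes(shape, bytes_per_element):
    """Raw storage size of a dense tensor, ignoring any framework overhead."""
    return math.prod(shape) * bytes_per_element


input_shape = (1, 1, 15000, 15000)
float32_bytes = tensor_nbytes(input_shape, 4)  # default dtype of torch.randn
float64_bytes = tensor_nbytes(input_shape, 8)  # default dtype of np.random.randn

print(f"elements: {math.prod(input_shape):,}")  # elements: 225,000,000
print(f"float32: {float32_bytes / 1e9:.2f} GB")  # float32: 0.90 GB
print(f"float64: {float64_bytes / 1e9:.2f} GB")  # float64: 1.80 GB
```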
Here is some minimal code to reproduce the problem:
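```python
import brevitas.nn as qnn
import numpy as np
import torch
import torch.nn as nn
from concrete.ml.torch.compile import compile_brevitas_qat_model

# Fixed 1x1000 kernel, shape (out_channels, in_channels, kH, kW)
weight = torch.randn(1, 1, 1, 1000)


class SimpleConvBrevitas(nn.Module):
    def __init__(self, bit_width):
        super().__init__()  # required before assigning submodules
        self.id1 = qnn.QuantIdentity(bit_width=bit_width)
        # kernel_size matches the (1, 1000) weight assigned below
        self.conv = qnn.QuantConv2d(
            1, 1, kernel_size=(1, 1000), weight_bit_width=bit_width, bias=False
        )
        self.conv.weight = nn.Parameter(weight)

    def forward(self, x):
        """Forward pass of the model."""
        x = self.id1(x)
        x = self.conv(x)
        return x


# Calibration input: one 15000 x 15000 single-channel image (float32, ~900 MB)
tensor_ = torch.randn(1, 1, 15000, 15000)
# Not used below; this float64 array takes another ~1.8 GB
array_ = np.random.randn(1, 1, 15000, 15000)

model = SimpleConvBrevitas(bit_width=8)
compiled_module = compile_brevitas_qat_model(model, tensor_, verbose=True, n_bits=8)
```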
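To get a feel for how heavy this model is, the sketch below applies the standard `torch.nn.Conv2d` output-size formula to this input and the 1 x 1000 kernel, assuming PyTorch's defaults of stride 1, no padding and no dilation. `conv2d_output_hw` is a hypothetical helper, and the multiply-accumulate count says nothing about how Concrete ML represents the circuit internally; it only describes the plain convolution.

```python
def conv2d_output_hw(in_h, in_w, k_h, k_w, stride=1, padding=0, dilation=1):
    """Output height/width of a 2D convolution (same formula as torch.nn.Conv2d)."""
    out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1
    out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1
    return out_h, out_w


out_h, out_w = conv2d_output_hw(15000, 15000, 1, 1000)
n_outputs = out_h * out_w
# One multiply-accumulate per kernel tap (1 x 1000) per output element
macs = n_outputs * 1 * 1000

print(out_h, out_w)  # 15000 14001
print(f"{n_outputs:,} outputs, {macs:,} MACs")  # 210,015,000 outputs, 210,015,000,000 MACs
```

So a single forward pass on this input already involves about 2.1 x 10^11 multiply-accumulates, and the float32 output alone is roughly 0.84 GB.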
Thanks for any help you can provide on this issue!