import matplotlib.pyplot as plt
import numpy as np
import torch
import torch
import torch.nn as nn
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

import brevitas.nn as qnn

class TinyCNN(nn.Module):
    """Small quantization-aware CNN built from Brevitas quantized layers.

    The input of every quantized layer (``QuantConv2d`` / ``QuantLinear``) is
    produced by a ``QuantIdentity`` activation quantizer, including the final
    ``fc1`` layer.
    """

    def __init__(self, n_classes, n_bits) -> None:
        """Build the network.

        Args:
            n_classes: number of output logits produced by ``fc1``.
            n_bits: bit width used for both the weights and the activations.
        """
        super().__init__()

        a_bits = n_bits
        w_bits = n_bits

        self.q1 = qnn.QuantIdentity(bit_width=a_bits, return_quant_tensor=True)
        self.conv1 = qnn.QuantConv2d(1, 4, 3, stride=1, padding=0, weight_bit_width=w_bits)
        self.q2 = qnn.QuantIdentity(bit_width=a_bits, return_quant_tensor=True)
        self.conv2 = qnn.QuantConv2d(4, 8, 2, stride=2, padding=0, weight_bit_width=w_bits)
        # Quantize the conv2 activations before the linear layer. Without this,
        # fc1 is fed an unquantized tensor.
        # NOTE(review): this is the likely cause of the empty AssertionError raised
        # by Concrete ML's compile_brevitas_qat_model; confirm against its QAT docs.
        self.q3 = qnn.QuantIdentity(bit_width=a_bits, return_quant_tensor=True)
        # 8 channels * 3 * 3 positions. Assumes single-channel 8x8 inputs
        # (conv1: 8 -> 6, conv2: 6 -> 3); TODO confirm the input size.
        self.fc1 = qnn.QuantLinear(
            8 * 3 * 3,
            n_classes,
            bias=True,
            weight_bit_width=w_bits,
        )

    def forward(self, x):
        """Return the class logits for the batch ``x``."""
        x = self.q1(x)
        x = torch.relu(self.conv1(x))
        x = self.q2(x)
        x = torch.relu(self.conv2(x))
        x = self.q3(x)
        # Flatten everything except the batch dimension before the linear layer.
        x = x.view(x.size(0), -1)
        return self.fc1(x)

# ---- Training ------------------------------------------------------------
# Seed PyTorch's global RNG once so model initialisation and loader shuffling
# repeat from run to run.
torch.manual_seed(seed=42)

def train_one_epoch(net, optimizer, train_loader):
    """Train ``net`` for one pass over ``train_loader``.

    Args:
        net: model to train; it is switched to training mode.
        optimizer: optimizer over ``net``'s parameters, stepped once per batch.
        train_loader: iterable of ``(data, target)`` batches. ``target`` is cast
            to ``long`` class indices. It does not need to support ``len()``.

    Returns:
        The mean of the per-batch cross-entropy losses.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    # CrossEntropyLoss applies log-softmax itself, so the network must output
    # raw logits (no softmax layer in the network).
    criterion = nn.CrossEntropyLoss()

    net.train()
    total_loss = 0.0
    n_batches = 0
    for data, target in train_loader:
        optimizer.zero_grad()
        output = net(data)
        loss = criterion(output, target.long())
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1

    # Count batches while iterating instead of calling len(), which avoids a
    # ZeroDivisionError on an empty loader.
    if n_batches == 0:
        raise ValueError("train_loader yielded no batches")
    return total_loss / n_batches

# Create the train and test data loaders.
# NOTE(review): x_train, y_train, x_test, y_test and N_EPOCHS are not defined in
# the visible code; they are assumed to be created earlier. TinyCNN expects
# single-channel images, presumably shaped (N, 1, 8, 8). Confirm.
train_dataset = TensorDataset(torch.Tensor(x_train), torch.Tensor(y_train))
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)

test_dataset = TensorDataset(torch.Tensor(x_test), torch.Tensor(y_test))
test_dataloader = DataLoader(test_dataset)

nets = []
bit_range = range(4, 7)

# Train one network per bit width with Adam and record the mean training loss
# of every epoch. (Test accuracy is evaluated separately, after training.)
losses = []

for n_bits in bit_range:
    net = TinyCNN(10, n_bits)
    losses_bits = []
    optimizer = torch.optim.Adam(net.parameters())

    for _ in tqdm(range(N_EPOCHS), desc=f"Training with {n_bits} bit weights and activations"):
        losses_bits.append(train_one_epoch(net, optimizer, train_dataloader))

    losses.append(losses_bits)
    nets.append(net)

# Plot the training-loss curve of each bit width.
fig = plt.figure(figsize=(8, 4))
for losses_bits in losses:
    plt.plot(losses_bits)
plt.ylabel("Cross Entropy Loss")
plt.xlabel("Epoch")
plt.legend(list(map(str, bit_range)))
plt.title("Training set loss during training")
plt.grid(True)
plt.show()

def test_torch(net, n_bits, test_loader):
    """Evaluate ``net`` on ``test_loader``, print and return its top-1 accuracy.

    Args:
        net: model producing per-class scores of shape ``(batch, n_classes)``.
        n_bits: bit width of the model; used only in the printed message.
        test_loader: iterable of ``(data, target)`` batches of any batch size.

    Returns:
        The fraction of correctly classified samples, in ``[0, 1]``.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    net.eval()

    # Count samples as they arrive. len(test_loader) is the number of batches,
    # not samples, so it is only right when batch_size == 1.
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for data, target in test_loader:
            y_pred = net(data).argmax(1).detach().numpy()
            y_true = target.numpy().astype(np.int64)
            n_correct += int(np.sum(y_pred == y_true))
            n_total += y_true.shape[0]

    if n_total == 0:
        raise ValueError("test_loader yielded no samples")

    accuracy = n_correct / n_total
    print(f"Test accuracy for {n_bits}-bit weights and activations: {accuracy * 100:.2f}%")
    return accuracy


# Despite its name this is an evaluation helper, not a pytest test.
test_torch.__test__ = False

# Evaluate each trained network on the test set, pairing it with its bit width.
for n_bits, net in zip(bit_range, nets):
    test_torch(net, n_bits, test_dataloader)

# Compile the second trained network (5-bit, since bit_range starts at 4) with
# the training inputs as the calibration set.
# NOTE(review): compile_brevitas_qat_model is not imported in the visible code;
# presumably it comes from Concrete ML (concrete.ml.torch.compile). Confirm.
qat_net = nets[1]
model = compile_brevitas_qat_model(qat_net, x_train, verbose=True)

# This is the point where we get an empty assertion error:
#
# AssertionError: