import torch

import torch.nn as nn

import brevitas.nn as qnn

from torch.utils.data import TensorDataset, DataLoader

import torch.optim as optim

from tqdm import tqdm

import matplotlib.pyplot as plt

class TinyCNN(nn.Module):

def **init**(self, n_classes, n_bits) → None:

super().**init**()

```
a_bits = n_bits
w_bits = n_bits
self.q1 = qnn.QuantIdentity(bit_width=a_bits, return_quant_tensor=True)
self.conv1 = qnn.QuantConv2d(1, 4, 3, stride=1, padding=0, weight_bit_width=w_bits)
self.q2 = qnn.QuantIdentity(bit_width=a_bits, return_quant_tensor=True)
self.conv2 = qnn.QuantConv2d(4, 8, 2, stride=2, padding=0, weight_bit_width=w_bits)
self.fc1 = qnn.QuantLinear(
8 * 3 * 3,
n_classes,
bias=True,
weight_bit_width=w_bits,
)
def forward(self, x):
x = self.q1(x)
x = self.conv1(x)
x = torch.relu(x)
x = self.q2(x)
x = self.conv2(x)
x = torch.relu(x)
# Flatten the tensor before passing it to the fully connected layer
x = x.view(x.size(0), -1)
x = self.fc1(x)
return x
```

# Now the training part

torch.manual_seed(42)

def train_one_epoch(net, optimizer, train_loader):

# Cross Entropy loss for classification when not using a softmax layer in the network

loss = nn.CrossEntropyLoss()

```
net.train()
avg_loss = 0
for data, target in train_loader:
optimizer.zero_grad()
output = net(data)
loss_net = loss(output, target.long())
loss_net.backward()
optimizer.step()
avg_loss += loss_net.item()
return avg_loss / len(train_loader)
```

# Create a train data loader

train_dataset = TensorDataset(torch.Tensor(x_train), torch.Tensor(y_train))

train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)

test_dataset = TensorDataset(torch.Tensor(x_test), torch.Tensor(y_test))

test_dataloader = DataLoader(test_dataset)

nets = []

bit_range = range(4, 7)

# Train the network with Adam, output the test set accuracy every epoch

losses = []

for n_bits in bit_range:

net = TinyCNN(10, n_bits)

losses_bits = []

optimizer = torch.optim.Adam(net.parameters())

for epoch in tqdm(range(N_EPOCHS), desc=f"Training with {n_bits} bit weights and activations"):

losses_bits.append(train_one_epoch(net, optimizer, train_dataloader))

```
losses.append(losses_bits)
nets.append(net)
```

fig = plt.figure(figsize=(8, 4))

for losses_bits in losses:

plt.plot(losses_bits)

plt.ylabel(“Cross Entropy Loss”)

plt.xlabel(“Epoch”)

plt.legend(list(map(str, bit_range)))

plt.title(“Training set loss during training”)

plt.grid(True)

plt.show()

def test_torch(net, n_bits, test_loader):

net.eval()

all_y_pred = np.zeros(len(test_loader), dtype=np.int64)

all_targets = np.zeros(len(test_loader), dtype=np.int64)

```
idx = 0
for data, target in test_loader:
endidx = idx + target.shape[0]
all_targets[idx:endidx] = target.numpy()
output = net(data).argmax(1).detach().numpy()
all_y_pred[idx:endidx] = output
idx += target.shape[0]
n_correct = np.sum(all_targets == all_y_pred)
print(f"Test accuracy for {n_bits}-bit weights and activations: {n_correct / len(test_loader) * 100:.2f}%")
```

# Test each network in the list

for idx, net in enumerate(nets):

test_torch(net, bit_range[idx], test_dataloader)

)

# try compiling the model

model= compile_brevitas_qat_model(nets[1], x_train,verbose=True)

It is at the point we get an empty assertion error

AssertionError: