import numpy as np

import matplotlib.pyplot as plt
import torch
import torch  # duplicate kept from original
import torch.nn as nn
import torch.nn as nn  # duplicate kept from original
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

import brevitas.nn as qnn
class TinyCNN(nn.Module):
    """A small quantization-aware CNN built with Brevitas layers.

    Activations and weights are both quantized to ``n_bits`` bits; the
    QuantIdentity layers quantize the activations entering each conv.

    Args:
        n_classes: number of output classes of the final linear layer.
        n_bits: bit width used for both weights and activations.
    """

    # BUG FIX: the original had `init` instead of `__init__` (and called
    # `super().init()`), so the module was never actually initialized, and
    # used a literal `→` instead of `->` in the signature (SyntaxError).
    def __init__(self, n_classes, n_bits) -> None:
        super().__init__()
        a_bits = n_bits
        w_bits = n_bits
        self.q1 = qnn.QuantIdentity(bit_width=a_bits, return_quant_tensor=True)
        self.conv1 = qnn.QuantConv2d(1, 4, 3, stride=1, padding=0, weight_bit_width=w_bits)
        self.q2 = qnn.QuantIdentity(bit_width=a_bits, return_quant_tensor=True)
        self.conv2 = qnn.QuantConv2d(4, 8, 2, stride=2, padding=0, weight_bit_width=w_bits)
        # 8 * 3 * 3 assumes the input spatial size yields a 3x3 map after
        # conv1 (3x3, stride 1) and conv2 (2x2, stride 2) — TODO confirm
        # against the actual input resolution (e.g. 8x8 digits).
        self.fc1 = qnn.QuantLinear(
            8 * 3 * 3,
            n_classes,
            bias=True,
            weight_bit_width=w_bits,
        )

    def forward(self, x):
        x = self.q1(x)
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.q2(x)
        x = self.conv2(x)
        x = torch.relu(x)
        # Flatten the tensor before passing it to the fully connected layer
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x
# Now the training part.
torch.manual_seed(42)
def train_one_epoch(net, optimizer, train_loader):
    """Run one optimization pass over ``train_loader``.

    Args:
        net: the torch module to train (must output raw class logits).
        optimizer: optimizer updating ``net``'s parameters.
        train_loader: DataLoader yielding (data, target) batches.

    Returns:
        The mean per-batch cross-entropy loss over the epoch, as a float.
    """
    # CrossEntropyLoss expects raw logits, since the network has no softmax.
    criterion = nn.CrossEntropyLoss()
    net.train()
    total_loss = 0.0
    for batch, labels in train_loader:
        optimizer.zero_grad()
        logits = net(batch)
        batch_loss = criterion(logits, labels.long())
        batch_loss.backward()
        optimizer.step()
        total_loss += batch_loss.item()
    return total_loss / len(train_loader)
# Create the train and test data loaders.
# NOTE(review): x_train, y_train, x_test, y_test are not defined in this file —
# presumably arrays produced by an earlier cell (e.g. a train/test split of a
# digits dataset); confirm they exist before running this script standalone.
train_dataset = TensorDataset(torch.Tensor(x_train), torch.Tensor(y_train))
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataset = TensorDataset(torch.Tensor(x_test), torch.Tensor(y_test))
# batch_size defaults to 1 here, so len(test_dataloader) equals the sample count.
test_dataloader = DataLoader(test_dataset)
# One trained network will be collected per bit width.
nets = []
# Quantization bit widths to sweep: 4, 5 and 6 bits.
bit_range = range(4, 7)
# Train one network per bit width with Adam, recording the training loss per epoch.
# Per-bit-width training loss history: losses[i][e] is the mean loss of
# bit_range[i] at epoch e.
losses = []
# NOTE(review): N_EPOCHS is not defined anywhere in this file — it must be set
# before this loop (e.g. N_EPOCHS = 10); confirm against the original notebook.
for n_bits in bit_range:
    # 10 output classes; weights and activations quantized to n_bits.
    net = TinyCNN(10, n_bits)
    losses_bits = []
    optimizer = torch.optim.Adam(net.parameters())
    for epoch in tqdm(range(N_EPOCHS), desc=f"Training with {n_bits} bit weights and activations"):
        losses_bits.append(train_one_epoch(net, optimizer, train_dataloader))
    losses.append(losses_bits)
    nets.append(net)
# Plot the training-loss curve of each bit width on one figure.
# BUG FIX: the original used typographic quotes (“…”) in the label/title
# strings, which is a SyntaxError in Python; the string contents are unchanged.
fig = plt.figure(figsize=(8, 4))
for losses_bits in losses:
    plt.plot(losses_bits)
plt.ylabel("Cross Entropy Loss")
plt.xlabel("Epoch")
plt.legend(list(map(str, bit_range)))
plt.title("Training set loss during training")
plt.grid(True)
plt.show()
def test_torch(net, n_bits, test_loader):
net.eval()
all_y_pred = np.zeros(len(test_loader), dtype=np.int64)
all_targets = np.zeros(len(test_loader), dtype=np.int64)
idx = 0
for data, target in test_loader:
endidx = idx + target.shape[0]
all_targets[idx:endidx] = target.numpy()
output = net(data).argmax(1).detach().numpy()
all_y_pred[idx:endidx] = output
idx += target.shape[0]
n_correct = np.sum(all_targets == all_y_pred)
print(f"Test accuracy for {n_bits}-bit weights and activations: {n_correct / len(test_loader) * 100:.2f}%")
# Test each network in the list.
# Report the test accuracy of every trained network (one per bit width).
# BUG FIX: the original had a stray closing parenthesis after this loop,
# which is a SyntaxError; it has been removed.
for idx, net in enumerate(nets):
    test_torch(net, bit_range[idx], test_dataloader)
# Try compiling the model (quantization-aware training model compilation).
model= compile_brevitas_qat_model(nets[1], x_train,verbose=True)
# It is at this point that we get an empty assertion error:
#   AssertionError: