I am trying to get a breakdown of the execution time spent in each layer of a neural network, but I cannot find a way to do this. I tried inserting print statements in forward(), but that does not work.
# Import necessary libraries
import sys
import time
from datetime import datetime

import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch import nn

from concrete.ml.torch.compile import compile_torch_model

# Load the MNIST dataset
from mlxtend.data import mnist_data

X, y = mnist_data()
X = np.expand_dims(X.reshape((-1, 28, 28)), 1)
x_train, x_test, y_train, y_test = train_test_split(
    X[:1000], y[:1000], test_size=0.25, shuffle=True, random_state=42
)

# Define the neural network
class NNX(nn.Module):
    def __init__(self, x=20) -> None:
        super().__init__()
        self.conv = nn.Conv2d(1, 2, (10, 11), stride=1, padding=1)
        self.dense1 = nn.Linear(840, 92, bias=True)
        self.dense_layers = nn.ModuleList(
            [nn.Linear(92, 92, bias=True) for _ in range(x - 3)]
        )
        self.fc = nn.Linear(92, 10)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        print(f'layer1: {datetime.now().strftime("%H:%M:%S.%f")}')
        x = torch.relu(self.dense1(x.view(-1, 840)))
        print(f'layer2: {datetime.now().strftime("%H:%M:%S.%f")}')
        for i, layer in enumerate(self.dense_layers):
            x = torch.relu(layer(x))
            print(f'layer {i}: {datetime.now().strftime("%H:%M:%S.%f")}')
        x = torch.relu(self.fc(x))
        print(f'layer last: {datetime.now().strftime("%H:%M:%S.%f")}')
        return x

# Compile the model for concrete execution
if len(sys.argv) == 2:
    x = int(sys.argv[1])
    print(f"x={x}")
else:
    x = 20
net = NNX(x=x)

from concrete import fhe

configuration = fhe.Configuration(show_statistics=True)
n_bits = 6
model_input = np.random.rand(1, 1, 28, 28)
q_module = compile_torch_model(
    net,
    model_input,
    rounding_threshold_bits=n_bits,
    p_error=0.1,
    configuration=configuration,
)
The output is something like:
layer1: 16:48:59.200496
layer2: 16:48:59.201048
layer 0: 16:48:59.201433
layer 1: 16:48:59.201793
layer 2: 16:48:59.202138
layer 3: 16:48:59.202517
layer 4: 16:48:59.202867
layer 5: 16:48:59.203205
layer 6: 16:48:59.203536
layer 7: 16:48:59.203878
layer 8: 16:48:59.204215
layer 9: 16:48:59.204666
layer 10: 16:48:59.205037
layer 11: 16:48:59.205429
layer 12: 16:48:59.205769
layer 13: 16:48:59.206118
layer 14: 16:48:59.206453
layer 15: 16:48:59.206781
layer 16: 16:48:59.207111
layer 17: 16:48:59.207464
layer 18: 16:48:59.207795
layer 19: 16:48:59.208140
layer last: 16:48:59.208844
layer fc: 16:48:59.222547
All the layers' timestamps are almost identical, so they cannot be the times at which each layer actually starts executing.
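For comparison, I can get per-layer timings of the clear (uncompiled) PyTorch model with forward hooks. A minimal sketch (the attach_layer_timers helper below is my own, and it only times the plain torch execution, not the compiled FHE circuit):

import time
from collections import defaultdict

import torch
from torch import nn

def attach_layer_timers(model: nn.Module):
    # Record wall-clock duration of every leaf module's forward pass.
    durations = defaultdict(float)
    starts = {}

    def pre_hook(name):
        def hook(module, inputs):
            starts[name] = time.perf_counter()
        return hook

    def post_hook(name):
        def hook(module, inputs, output):
            durations[name] += time.perf_counter() - starts[name]
        return hook

    for name, module in model.named_modules():
        if not list(module.children()):  # leaf modules only
            module.register_forward_pre_hook(pre_hook(name))
            module.register_forward_hook(post_hook(name))
    return durations

net = NNX(x=20)
durations = attach_layer_timers(net)
net(torch.randn(1, 1, 28, 28))
for name, seconds in durations.items():
    print(f"{name}: {seconds * 1e3:.3f} ms")

This works for the clear model, but it is the FHE execution I want to break down.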
I suspect this is difficult because, as I understand it, the whole network is compiled together into a single circuit, so there is no boundary between the layers.
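The only workaround I can think of is to compile a single layer on its own and time its FHE execution in isolation. A rough sketch, assuming compile_torch_model accepts any nn.Module and that the returned QuantizedModule exposes fhe_circuit and a forward(..., fhe="execute") mode as in recent Concrete ML versions (the SingleDense class and the shapes below are just for illustration):

import time

import numpy as np
import torch
from torch import nn

from concrete.ml.torch.compile import compile_torch_model

class SingleDense(nn.Module):
    # One 92x92 dense layer + ReLU, matching one hidden layer of NNX above.
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(92, 92, bias=True)

    def forward(self, x):
        return torch.relu(self.dense(x))

layer_net = SingleDense()
layer_inputset = np.random.rand(100, 92)
q_layer = compile_torch_model(
    layer_net, layer_inputset, rounding_threshold_bits=6, p_error=0.1
)

q_layer.fhe_circuit.keygen()  # generate keys up front so timing covers execution only

sample = np.random.rand(1, 92)
start = time.perf_counter()
q_layer.forward(sample, fhe="execute")  # run the actual FHE circuit
print(f"one dense layer in FHE: {time.perf_counter() - start:.3f} s")

But this changes the quantization and parameters compared to compiling the full network, so it would only give a rough estimate per layer.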
If there is a way, I would be willing to make changes to the compiler or to tfhe-rs and recompile.