Question on PBS Number When Implementing NN-20 Based on the PBS Whitepaper

Hello community,

I have implemented NN-20 following the guidelines provided in this link. Below is the code I used:

# Import necessary libraries
import sys
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch import nn
from concrete.ml.torch.compile import compile_torch_model

# Load the MNIST dataset
from mlxtend.data import mnist_data
X, y = mnist_data()
X = np.expand_dims(X.reshape((-1, 28, 28)), 1)
x_train, x_test, y_train, y_test = train_test_split(
    X[:1000], y[:1000], test_size=0.25, shuffle=True, random_state=42
)

# Define the neural network
class NNX(nn.Module):
    def __init__(self, x=20) -> None:
        super().__init__()
        self.conv = nn.Conv2d(1, 2, (10, 11), stride=1, padding=1)
        self.dense1 = nn.Linear(840, 92, bias=True)
        self.dense_layers = nn.ModuleList([nn.Linear(92, 92, bias=True) for _ in range(x-3)])
        self.fc = nn.Linear(92, 10)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        x = torch.relu(self.dense1(x.view(-1, 840)))
        for layer in self.dense_layers:
            x = torch.relu(layer(x))
        x = torch.relu(self.fc(x))
        return x

# Compile the model for concrete execution
if len(sys.argv) == 2:
    x = int(sys.argv[1])
    print(f"x={x}")
else:
    x = 20

net = NNX(x=x)
from concrete import fhe
configuration = fhe.Configuration(show_statistics=True)
n_bits = 6

model_input = np.random.rand(1, 1, 28, 28)
q_module = compile_torch_model(
    net,
    model_input,
    rounding_threshold_bits=n_bits,
    p_error=0.1,
    configuration=configuration,
)

I want to highlight that, in my understanding, the number 20 does not represent the count of dense layers with output size 92, but rather the total count of both convolutional and dense layers. Therefore, between the first convolutional layer and the last dense layer for classification, there are 18 dense layers with output size 92.
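
As a quick sanity check of this interpretation, here is a small sketch that counts the layers directly, reusing the NNX class and the torch import from the snippet above:

# Count the layers of NNX(x=20): 1 conv + 19 linear layers, of which
# 18 (dense1 plus the 17 layers in dense_layers) have output size 92.
model = NNX(x=20)
convs = [m for m in model.modules() if isinstance(m, nn.Conv2d)]
linears = [m for m in model.modules() if isinstance(m, nn.Linear)]
hidden_92 = [m for m in linears if m.out_features == 92]
print(len(convs), len(linears), len(hidden_92))  # 1, 19, 18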

However, the resulting statistics are puzzling:

Statistics
------------------------------------------------------------------------
size_of_secret_keys: 51240
size_of_bootstrap_keys: 0
size_of_keyswitch_keys: 0
size_of_inputs: 8034432
size_of_outputs: 163920
p_error: 0.09957498721402484
global_p_error: 1.0
complexity: 1930082497588.0
programmable_bootstrap_count: 35084
programmable_bootstrap_count_per_parameter: {
    BootstrapKeyParam(polynomial_size=256, glwe_dimension=5, input_lwe_dimension=512, level=2, base_log=16, variance=7.177464159383647e-31): 32578
    BootstrapKeyParam(polynomial_size=1024, glwe_dimension=2, input_lwe_dimension=512, level=14, base_log=3, variance=4.70197740328915e-38): 2414
    BootstrapKeyParam(polynomial_size=2048, glwe_dimension=1, input_lwe_dimension=517, level=8, base_log=5, variance=4.70197740328915e-38): 92
}
key_switch_count: 37580
key_switch_count_per_parameter: {
    KeyswitchKeyParam(level=1, base_log=9, variance=3.657038691888256e-12): 32578
    KeyswitchKeyParam(level=4, base_log=4, variance=3.657038691888256e-12): 2414
    KeyswitchKeyParam(level=5, base_log=8, variance=7.177464159383647e-31): 2404
    KeyswitchKeyParam(level=6, base_log=3, variance=2.7627281868226664e-12): 92
    KeyswitchKeyParam(level=3, base_log=12, variance=7.177464159383647e-31): 92
}
packing_key_switch_count: 0
clear_addition_count: 70168
clear_addition_count_per_parameter: {
    LweSecretKeyParam(dimension=1280): 70168
}
encrypted_addition_count: 347066
encrypted_addition_count_per_parameter: {
    LweSecretKeyParam(dimension=1280): 347066
}
clear_multiplication_count: 347066
clear_multiplication_count_per_parameter: {
    LweSecretKeyParam(dimension=1280): 347066
}
encrypted_negation_count: 32578
encrypted_negation_count_per_parameter: {
    LweSecretKeyParam(dimension=1280): 32578
}
------------------------------------------------------------------------

The statistics report a programmable_bootstrap_count of 35084. This is quite surprising because, according to the link, the PBS count should be 840 + 92*18 = 2496, more than ten times less than the reported number. Could someone explain such a discrepancy?

Thank you in advance for your insights.

Hi,

ReLU can be rewritten by Concrete Python as many PBSs, with at least 1 PBS per bit of the input plus some extra PBSs to handle the bits in chunks. I am not sure that is the case here, but it could explain what you see.

This can be controlled with relu_on_bits_threshold when creating the fhe.Configuration.
See the ReLU documentation.
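
A minimal sketch, assuming relu_on_bits_threshold is passed as a keyword of fhe.Configuration (the value below is chosen for illustration only):

# Sketch: pass relu_on_bits_threshold to the Configuration used at compile time.
from concrete import fhe

configuration = fhe.Configuration(
    show_statistics=True,
    relu_on_bits_threshold=8,
)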

Hello @IsaacQuebec,
Could you confirm that you are using one of Concrete ML's latest versions? We just released version 1.5.0 last week, so you might want to try it out.

However, as @Rudy_Sicard mentioned, things have changed quite a bit since the post you are referencing (it is 2-3 years old), and several speed-up optimizations are now done under the hood. In particular, some of them, like the rounding feature you are using, do require additional 1-bit PBSs.

Still, it's true that this number looks a bit high even taking that into account. So if you confirm that you still see it with our latest release, we'll take a look at it!

I set relu_on_bits_threshold to 17. I thought the PBS count would then decrease, but in fact there is no change, which confuses me.

As Rudy mentioned, 1-bit PBSs are inserted into this computation graph for each ReLU.

Suppose a dense layer computes Wx, where W is the weight matrix and x is the input activation vector. The bit-width of the result is obtained by taking the log2 of the maximal absolute value in the resulting vector.

The number of 1-bit PBSs added for each ReLU is bitwidth - rounding_threshold_bits. Thus you end up with many more PBSs.
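
As a back-of-the-envelope illustration (the 13-bit accumulator width below is an assumed value, not taken from your model):

# Illustration with an assumed accumulator bit-width: each ReLU costs its
# table-lookup PBS plus (bitwidth - rounding_threshold_bits) 1-bit PBSs
# inserted by the rounding step.
accumulator_bitwidth = 13                # assumed value for illustration
rounding_threshold_bits = 6
extra_1bit_pbs = accumulator_bitwidth - rounding_threshold_bits  # 7
n_relu = 840 + 92 * 18                   # ReLU count from the original estimate
print(n_relu * (1 + extra_1bit_pbs))     # 19968, an order of magnitude above 2496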

You can remove those PBSs, at the cost of losing full correctness, by switching to approximate rounding:

rounding_threshold_bits={'n_bits': n_bits, 'method': 'approximate'}

See Using Torch | Concrete ML
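
Putting it together, a minimal sketch of the compile call with approximate rounding, reusing net, model_input and configuration from your original snippet:

# Sketch: approximate rounding drops the extra 1-bit PBSs, at the cost of
# exact correctness of the rounding step.
q_module = compile_torch_model(
    net,
    model_input,
    rounding_threshold_bits={"n_bits": 6, "method": "approximate"},
    p_error=0.1,
    configuration=configuration,
)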