Tensor encryption latency behaves differently based on model split point

Hello, I was trying to find the best split point on a model between client and server and I found that different image tensors have different encryption latency per image depending on the split point.

To be more precise, the model is made of brevitas layers. The client executes the first part of the layers. Then I take the client output, quantize it with the QuantizedModule and encrypt it with the FHE Circuit. While doing this, I monitored the mean encryption latency per image and found that it changes drastically depending on the last layer executed by the client:

  • AvgPool2d and Conv2d outputs show encryption latencies per image that grow linearly with the dimension of the tensor, but Conv2d takes 5x more time to encrypt a sample (e.g. #elements_in_tensor = 25008, conv_latency=20sec, avgpool_latency=4sec, approximately)
  • ReLU outputs behave weirdly: in the beginning the latencies were low but inconsistent (4, 8, 5 sec per same tensor size), then the tensor size doubled and the latency spiked to 73sec.

I immediately thought that this behaviour is suspicious. Is there something affecting this that I’m not aware of? Or is there some implementation error?
Please, help…

Ideally you should be able to print the fhe_circuit; this would give us more information. I am not sure how exactly you are splitting the layers.

What I can see here is that AvgPool2d and Conv2d should be purely linear operations, while ReLU involves the costly PBS operation. Still, the latency should grow linearly with the input size, unless the input bit-width changes with the tensor size?

I have a VGG model with 23 layers (not counting the QuantIdentity layers) that is organized as two nn.Sequential called client_features and server_features.
The client side is just quantized, while the server side is encrypted.
I made 22 experiments: in the first experiment the client executed all the layers except the last one, in the second experiment the client executed all the layers except the last 2, and so on. For each experiment, I computed the mean encryption latency per image with the following code:


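from time import time

import brevitas.nn as qnn
import numpy as np
import torch
from torch import nn
from concrete.ml.torch.compile import compile_brevitas_qat_model

# X_fhe_test.shape = [120, 3, 32, 32]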

client_time = []
server_input = []
# Client forward latency for each sample
for sample in X_fhe_test:
    start_time = time()
    client_output = quant_vgg.client_features.to('cpu').forward(sample)
    client_time.append(time() - start_time)
    server_input.append(client_output.value.detach())

# Discard the first 20 latencies and compute the mean
top_client_time = client_time[20:]
mean_forward_time = round(np.mean(top_client_time), 4)

# Define server input as torch tensor
server_input = torch.stack(server_input)

# quantIdentity because, depending on where I split the model, it may 
# not start with a quantIdentity
model = nn.Sequential(
        qnn.QuantIdentity(bit_width=n_bits, act_quant=act_quant, return_quant_tensor=True),
        quant_vgg.server_features
)

# Quantized module
qmodel = compile_brevitas_qat_model(
    torch_model=model.to('cpu'),
    torch_inputset=server_input,
    n_bits=n_bits,
    rounding_threshold_bits=round_bits,
    p_error=p_error,
    device=compilation_device,
)
print("QuantizedModule compiled.")

server_input = server_input.numpy()
first_batch = server_input[:1]

# FHE Circuit
fhe_circuit = qmodel.compile(
    inputs=first_batch,
)
print("FHE circuit compiled.")

# Client encryption time
encryption_time = []
for sample in server_input:
    quant_sample = qmodel.quantize_input([sample])
    start_time = time()
    _ = fhe_circuit.encrypt(quant_sample)
    encryption_time.append(time() - start_time)

# Discard the first 20 encryption latencies and compute the mean
top_encryption_time = encryption_time[20:]
mean_encryption_time = round(np.mean(top_encryption_time), 4)
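The split schedule described before the timing code (experiment k leaves the last k layers to the server) can be sketched as follows. This is only an illustration: the layer names are hypothetical placeholders, and in the real setup the two halves would be slices of the model's layers rather than lists of strings.

```python
def enumerate_splits(layers):
    """Yield (n_server, client_layers, server_layers) for every split point.

    Experiment k leaves the last k layers to the server and gives the
    remaining leading layers to the client.
    """
    for n_server in range(1, len(layers)):
        yield n_server, layers[:-n_server], layers[-n_server:]


# Hypothetical stand-in names for the 23 VGG layers.
layers = [f"layer_{i}" for i in range(23)]
splits = list(enumerate_splits(layers))

for n_server, client_layers, server_layers in splits[:3]:
    print(n_server, len(client_layers), len(server_layers))
```

With 23 layers this gives 22 split points, matching the 22 experiments.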

What I noticed is that:

  • if I encrypt the output of a Conv2d, the encryption latency is kind of proportional to the size of the output, e.g.
    encrypted_MB=0.57 -> t_encrypt= approx. 19 sec
    encrypted_MB=1.32 -> t_encrypt= approx. 40 sec
  • if I encrypt the output of an AvgPool, the encryption latency is kind of proportional to the size of the output, e.g.
    encrypted_MB=0.57 -> t_encrypt= approx. 5 sec
    encrypted_MB=0.29 -> t_encrypt= approx. 2 sec
  • if I encrypt the output of a ReLU, the encryption latency is unstable, e.g.
    encrypted_MB=0.57 -> t_encrypt= approx. 5 sec thrice, approx. 9 sec once
    encrypted_MB=1.32 -> t_encrypt= approx. 9 sec once, approx 73 sec once
    encrypted_MB=2.82 -> t_encrypt= approx. 157 sec
    encrypted_MB=1.50 -> t_encrypt= approx. 85 sec

NB: when I say thrice and once, I’m referring to different experiments in which I split the VGG in different points and the tensor dimension happened to be the same.

What I wanted to know is if this behaviour is correct or if there is some implementation error (if you see it, where?).
If so, what does the tensor encryption latency depend on? Why are Conv2d and AvgPool latencies so different, despite having the same tensor size?
Why do ReLU’s outputs behave so strangely?

(I deleted the previous reply because it was a bit confusing)

Thanks for the information!

I don’t see any problem with your code.

One important point is that you don't run the keygen explicitly, so when you call encrypt, the key generation could be included in the time you record. We have a caching system for keys, so some of them might be reused, in which case no keygen is needed.

You could call qmodel.fhe_circuit.keygen() before the encryption and record that time separately.

If the timings are still very different, we can investigate further and look at qmodel.fhe_circuit.statistics.
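A minimal sketch of timing the two steps separately is shown below. It assumes only that the circuit object exposes keygen() and encrypt(), as used in this thread. FakeCircuit is a hypothetical stand-in so the snippet runs without Concrete ML installed, and the helper name and warmup value are illustrative.

```python
from statistics import mean
from time import perf_counter


def time_keygen_and_encrypt(circuit, samples, warmup=20):
    """Return (keygen_seconds, mean_encrypt_seconds), timed separately."""
    start = perf_counter()
    circuit.keygen()  # keys are generated here, outside the encryption loop
    keygen_time = perf_counter() - start

    latencies = []
    for sample in samples:
        start = perf_counter()
        circuit.encrypt(sample)
        latencies.append(perf_counter() - start)

    steady = latencies[warmup:]  # drop the warm-up measurements
    if not steady:
        raise ValueError("need more samples than warm-up iterations")
    return keygen_time, mean(steady)


class FakeCircuit:
    """Hypothetical stand-in for the real FHE circuit (assumption)."""

    def __init__(self):
        self.calls = []

    def keygen(self):
        self.calls.append("keygen")

    def encrypt(self, sample):
        self.calls.append("encrypt")
        return sample


circuit = FakeCircuit()
keygen_time, mean_encrypt = time_keygen_and_encrypt(circuit, list(range(30)))
print(f"keygen: {keygen_time:.6f}s, mean encrypt: {mean_encrypt:.6f}s")
```

With the real circuit, the second value is the per-sample encryption latency without any key generation cost.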

Thanks for your reply!

As you suggested, I ran the key generation separately before the encryption. The first encryption is now significantly faster; this is because the key was previously generated during that first call, right?
However, since I already discarded the first 20 latencies before computing the mean, it didn’t affect the results.

Actually, I printed the statistics already for some splits of the model, I just didn’t know what to look for. I’ll post one for each case (Conv2d, AvgPool, ReLU outputs), it could be useful to compare them. I tried to format it in a way that is easier to read:

  • Conv2d output (client executed all layers except the last 3, encrypted_MB=0.57 -> t_encrypt= approx. 19 sec)
{'size_of_secret_keys': 347056, 'size_of_bootstrap_keys': 4668260352, 'size_of_keyswitch_keys': 1039138816, 
'p_error': 7.726075413135422e-08, 'global_p_error': 0.002014900278258791, 'complexity': 22134345773056.0, 
'size_of_inputs': 602112, 'size_of_outputs': 2621520, 

'programmable_bootstrap_count': 27136, 'programmable_bootstrap_count_per_parameter': {
BootstrapKeyParam(polynomial_size=8192, glwe_dimension=1, input_lwe_dimension=903, level=2, base_log=15, variance=4.70197740328915e-38): 25088, 
BootstrapKeyParam(polynomial_size=8192, glwe_dimension=1, input_lwe_dimension=593, level=2, base_log=15, variance=4.70197740328915e-38): 1536, 
BootstrapKeyParam(polynomial_size=32768, glwe_dimension=1, input_lwe_dimension=926, level=4, base_log=9, variance=4.70197740328915e-38): 512}, 
'programmable_bootstrap_count_per_tag': {}, 'programmable_bootstrap_count_per_tag_per_parameter': {}, 

'key_switch_count': 27136, 'key_switch_count_per_parameter': {
KeyswitchKeyParam(level=4, base_log=4, variance=4.372764418454018e-13): 25088, 
KeyswitchKeyParam(level=5, base_log=2, variance=2.7338004748751887e-08): 1536, 
KeyswitchKeyParam(level=10, base_log=2, variance=1.9271833105393818e-13): 512}, 
'key_switch_count_per_tag': {}, 'key_switch_count_per_tag_per_parameter': {}, 

'packing_key_switch_count': 0, 'packing_key_switch_count_per_parameter': {}, 
'packing_key_switch_count_per_tag': {}, 'packing_key_switch_count_per_tag_per_parameter': {}, 

'clear_addition_count': 28672, 'clear_addition_count_per_parameter': {
LweSecretKeyParam(dimension=8192): 28672}, 
'clear_addition_count_per_tag': {}, 'clear_addition_count_per_tag_per_parameter': {}, 

'encrypted_addition_count': 12851712, 'encrypted_addition_count_per_parameter': {
LweSecretKeyParam(dimension=8192): 12846592, 
LweSecretKeyParam(dimension=32768): 5120}, 
'encrypted_addition_count_per_tag': {}, 'encrypted_addition_count_per_tag_per_parameter': {}, 

'clear_multiplication_count': 12851712, 'clear_multiplication_count_per_parameter':  {
LweSecretKeyParam(dimension=8192): 12846592, 
LweSecretKeyParam(dimension=32768): 5120}, 
'clear_multiplication_count_per_tag': {}, 'clear_multiplication_count_per_tag_per_parameter': {}, 

'encrypted_negation_count': 1536, 'encrypted_negation_count_per_parameter': {
LweSecretKeyParam(dimension=8192): 1536}, 
'encrypted_negation_count_per_tag': {}, 'encrypted_negation_count_per_tag_per_parameter': {}}
  • AvgPool output (client executed all layers except the last 6, encrypted_MB=0.57 -> t_encrypt= approx. 5 sec)
{'size_of_secret_keys': 294552, 'size_of_bootstrap_keys': 20278640640, 'size_of_keyswitch_keys': 1730142208, 
'p_error': 9.473756358082791e-08, 'global_p_error': 0.006167934496865424, 'complexity': 1210812549475840.0, 
'size_of_inputs': 602112, 'size_of_outputs': 2621520, 

'programmable_bootstrap_count': 453632, 'programmable_bootstrap_count_per_parameter': {
BootstrapKeyParam(polynomial_size=512, glwe_dimension=5, input_lwe_dimension=611, level=2, base_log=16, variance=4.70197740328915e-38): 402944, 
BootstrapKeyParam(polynomial_size=32768, glwe_dimension=1, input_lwe_dimension=880, level=21, base_log=2, variance=4.70197740328915e-38): 50688},
'programmable_bootstrap_count_per_tag': {}, 'programmable_bootstrap_count_per_tag_per_parameter': {}, 

'key_switch_count': 503808, 'key_switch_count_per_parameter': {
KeyswitchKeyParam(level=5, base_log=2, variance=1.4397555853147048e-08): 402944, 
KeyswitchKeyParam(level=18, base_log=1, variance=9.921769535221788e-13): 50688, 
KeyswitchKeyParam(level=2, base_log=21, variance=4.70197740328915e-38): 50176}, 
'key_switch_count_per_tag': {}, 'key_switch_count_per_tag_per_parameter': {}, 

'packing_key_switch_count': 0, 'packing_key_switch_count_per_parameter': {}, 
'packing_key_switch_count_per_tag': {}, 'packing_key_switch_count_per_tag_per_parameter': {}, 

'clear_addition_count': 906752, 'clear_addition_count_per_parameter': {
LweSecretKeyParam(dimension=2560): 906752}, 
'clear_addition_count_per_tag': {}, 'clear_addition_count_per_tag_per_parameter': {}, 

'encrypted_addition_count': 244464128, 'encrypted_addition_count_per_parameter': {
LweSecretKeyParam(dimension=2560): 244459008, 
LweSecretKeyParam(dimension=32768): 5120}, 
'encrypted_addition_count_per_tag': {}, 'encrypted_addition_count_per_tag_per_parameter': {}, 

'clear_multiplication_count': 244464128, 'clear_multiplication_count_per_parameter': {
LweSecretKeyParam(dimension=2560): 244459008, 
LweSecretKeyParam(dimension=32768): 5120}, 
'clear_multiplication_count_per_tag': {}, 'clear_multiplication_count_per_tag_per_parameter': {}, 

'encrypted_negation_count': 402944, 'encrypted_negation_count_per_parameter': {
LweSecretKeyParam(dimension=2560): 402944}, 
'encrypted_negation_count_per_tag': {}, 'encrypted_negation_count_per_tag_per_parameter': {}}
  • ReLU output (client executed all layers except the last 12, encrypted_MB=1.32 -> t_encrypt= approx. 73 sec)
{'size_of_secret_keys': 470032, 'size_of_bootstrap_keys': 15955197952, 'size_of_keyswitch_keys': 4467261440, 
'p_error': 8.338687633247801e-08, 'global_p_error': 0.014227299033742025, 'complexity': 1609237816835584.0, 
'size_of_inputs': 1382400, 'size_of_outputs': 2621520, 

'programmable_bootstrap_count': 892672, 'programmable_bootstrap_count_per_parameter': {
BootstrapKeyParam(polynomial_size=16384, glwe_dimension=1, input_lwe_dimension=906, level=6, base_log=6, variance=4.70197740328915e-38): 12544,
BootstrapKeyParam(polynomial_size=512, glwe_dimension=4, input_lwe_dimension=638, level=2, base_log=16, variance=8.442253112932959e-31): 754176,
BootstrapKeyParam(polynomial_size=32768, glwe_dimension=1, input_lwe_dimension=873, level=13, base_log=3, variance=4.70197740328915e-38): 100864, 
BootstrapKeyParam(polynomial_size=4096, glwe_dimension=1, input_lwe_dimension=1041, level=5, base_log=8, variance=4.70197740328915e-38): 25088},
'programmable_bootstrap_count_per_tag': {}, 'programmable_bootstrap_count_per_tag_per_parameter': {},

'key_switch_count': 1005568, 'key_switch_count_per_parameter': {
KeyswitchKeyParam(level=9, base_log=2, variance=3.9295523001771657e-13): 12544, 
KeyswitchKeyParam(level=3, base_log=12, variance=8.442253112932959e-31): 87808, 
KeyswitchKeyParam(level=4, base_log=3, variance=5.502647693575897e-09): 754176, 
KeyswitchKeyParam(level=18, base_log=1, variance=1.273169281266184e-12): 100864, 
KeyswitchKeyParam(level=2, base_log=8, variance=3.204481389236339e-15): 25088, 
KeyswitchKeyParam(level=2, base_log=17, variance=8.442253112932959e-31): 25088}, 
'key_switch_count_per_tag': {}, 'key_switch_count_per_tag_per_parameter': {}, 

'packing_key_switch_count': 0, 'packing_key_switch_count_per_parameter': {}, 
'packing_key_switch_count_per_tag': {}, 'packing_key_switch_count_per_tag_per_parameter': {}, 

'clear_addition_count': 1709568, 'clear_addition_count_per_parameter': {
LweSecretKeyParam(dimension=2048): 1709568}, 
'clear_addition_count_per_tag': {}, 'clear_addition_count_per_tag_per_parameter': {}, 

'encrypted_addition_count': 443913728, 'encrypted_addition_count_per_parameter': {
LweSecretKeyParam(dimension=16384): 12845056, 
LweSecretKeyParam(dimension=2048): 418218496, 
LweSecretKeyParam(dimension=32768): 12850176}, 
'encrypted_addition_count_per_tag': {}, 'encrypted_addition_count_per_tag_per_parameter': {}, 

'clear_multiplication_count': 443913728, 'clear_multiplication_count_per_parameter': {
LweSecretKeyParam(dimension=16384): 12845056, 
LweSecretKeyParam(dimension=2048): 418218496, 
LweSecretKeyParam(dimension=32768): 12850176}, 
'clear_multiplication_count_per_tag': {}, 'clear_multiplication_count_per_tag_per_parameter': {}, 

'encrypted_negation_count': 754176, 'encrypted_negation_count_per_parameter': {
LweSecretKeyParam(dimension=2048): 754176}, 
'encrypted_negation_count_per_tag': {}, 'encrypted_negation_count_per_tag_per_parameter': {}}
  • ReLU output (client executed all layers except the last 14, encrypted_MB=1.32 -> t_encrypt= approx. 9 sec)
{'size_of_secret_keys': 469976, 'size_of_bootstrap_keys': 15995994112, 'size_of_keyswitch_keys': 5835653120,
'p_error': 9.837142123243636e-08, 'global_p_error': 0.020710787219196534, 'complexity': 2469920417048576.0, 
'size_of_inputs': 1382400, 'size_of_outputs': 2621520,

'programmable_bootstrap_count': 1328384, 'programmable_bootstrap_count_per_parameter': {
BootstrapKeyParam(polynomial_size=512, glwe_dimension=4, input_lwe_dimension=611, level=2, base_log=16, variance=8.442253112932959e-31): 1132288,
BootstrapKeyParam(polynomial_size=32768, glwe_dimension=1, input_lwe_dimension=873, level=13, base_log=3, variance=4.70197740328915e-38): 158464, 
BootstrapKeyParam(polynomial_size=16384, glwe_dimension=1, input_lwe_dimension=926, level=6, base_log=6, variance=4.70197740328915e-38): 12544,
BootstrapKeyParam(polynomial_size=4096, glwe_dimension=1, input_lwe_dimension=1041, level=5, base_log=8, variance=4.70197740328915e-38): 25088}, 
'programmable_bootstrap_count_per_tag': {}, 'programmable_bootstrap_count_per_tag_per_parameter': {},

'key_switch_count': 1441280, 'key_switch_count_per_parameter': {
KeyswitchKeyParam(level=5, base_log=2, variance=1.4397555853147048e-08): 1132288, 
KeyswitchKeyParam(level=18, base_log=1, variance=1.273169281266184e-12): 158464, 
KeyswitchKeyParam(level=10, base_log=2, variance=1.9271833105393818e-13): 12544,
KeyswitchKeyParam(level=3, base_log=12, variance=8.442253112932959e-31): 87808, 
KeyswitchKeyParam(level=2, base_log=8, variance=3.204481389236339e-15): 25088,
KeyswitchKeyParam(level=2, base_log=17, variance=8.442253112932959e-31): 25088}, 
'key_switch_count_per_tag': {}, 'key_switch_count_per_tag_per_parameter': {}, 

'packing_key_switch_count': 0, 'packing_key_switch_count_per_parameter': {},
'packing_key_switch_count_per_tag': {}, 'packing_key_switch_count_per_tag_per_parameter': {}, 

'clear_addition_count': 2580992, 'clear_addition_count_per_parameter': {
LweSecretKeyParam(dimension=2048): 2580992}, 
'clear_addition_count_per_tag': {}, 'clear_addition_count_per_tag_per_parameter': {}, 

'encrypted_addition_count': 577002240, 'encrypted_addition_count_per_parameter': {
LweSecretKeyParam(dimension=2048): 551307008, 
LweSecretKeyParam(dimension=32768): 25695232}, 
'encrypted_addition_count_per_tag': {}, 'encrypted_addition_count_per_tag_per_parameter': {}, 

'clear_multiplication_count': 577002240, 'clear_multiplication_count_per_parameter': {
LweSecretKeyParam(dimension=2048): 551307008, 
LweSecretKeyParam(dimension=32768): 25695232}, 
'clear_multiplication_count_per_tag': {}, 'clear_multiplication_count_per_tag_per_parameter': {}, 

'encrypted_negation_count': 1132288, 'encrypted_negation_count_per_parameter': {
LweSecretKeyParam(dimension=2048): 1132288}, 
'encrypted_negation_count_per_tag': {}, 'encrypted_negation_count_per_tag_per_parameter': {}}

What should we be looking at in the statistics?
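One way to narrow this down is to list only the entries that differ between two statistics dictionaries. The sketch below copies a few scalar values from the Conv2d and AvgPool dumps above. The helper name is hypothetical, and the per-parameter entries are left out for brevity.

```python
def differing_entries(stats_a, stats_b):
    """Return {key: (value_a, value_b)} for shared keys whose values differ."""
    shared = sorted(set(stats_a) & set(stats_b))
    return {k: (stats_a[k], stats_b[k]) for k in shared if stats_a[k] != stats_b[k]}


# Scalar entries copied from the Conv2d and AvgPool statistics above.
conv_stats = {
    "size_of_secret_keys": 347056,
    "size_of_inputs": 602112,
    "size_of_outputs": 2621520,
    "programmable_bootstrap_count": 27136,
}
avgpool_stats = {
    "size_of_secret_keys": 294552,
    "size_of_inputs": 602112,
    "size_of_outputs": 2621520,
    "programmable_bootstrap_count": 453632,
}

diff = differing_entries(conv_stats, avgpool_stats)
for name, (conv_value, avgpool_value) in diff.items():
    print(f"{name}: conv={conv_value} avgpool={avgpool_value}")
```

In these two dumps size_of_inputs is identical, so that entry alone does not separate the 19 sec case from the 5 sec case.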

As you suggested, I ran the key generation separately before the encryption. The first encryption is now significantly faster; this is because the key was previously generated during that first call, right?

Yes that’s right.

We don’t have much info on the encryption complexity unfortunately.

Maybe we will have more luck with the server.zip which should contain all the necessary crypto parameters:

fhe_circuit.server.save("server.zip")

Which file are we looking for?
I unzipped server.zip and I got

  • compilation_feedback.json
  • composition_rules.json
  • sharedlib.so
  • client.specs.json
  • program_info.concrete.params.json
  • is_simulated

client.specs.json is what we need to compare between your different circuits. It contains the crypto parameters the client needs to create the private key and encrypt.
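A small sketch for pulling the encryption parameters out of server.zip without unzipping by hand is shown below. The exact nesting inside client.specs.json is not spelled out in this thread, so the helper searches recursively for every "encryption" block instead of assuming a path. The function names and the fake archive built for the demo are hypothetical.

```python
import io
import json
import zipfile


def find_values(node, key):
    """Yield every value stored under `key` anywhere in a nested JSON object."""
    if isinstance(node, dict):
        for name, value in node.items():
            if name == key:
                yield value
            yield from find_values(value, key)
    elif isinstance(node, list):
        for item in node:
            yield from find_values(item, key)


def encryption_params(server_zip):
    """Return all "encryption" blocks found in client.specs.json."""
    with zipfile.ZipFile(server_zip) as archive:
        specs = json.loads(archive.read("client.specs.json"))
    return list(find_values(specs, "encryption"))


# Demo on an in-memory archive; the nesting used here is an assumption.
fake_specs = {"circuits": [{"inputs": [{"encryption": {"keyId": 0, "lweDimension": 8192}}]}]}
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as archive:
    archive.writestr("client.specs.json", json.dumps(fake_specs))
buffer.seek(0)

params = encryption_params(buffer)
print(params)
```

On a real archive the result may also include blocks for the circuit outputs, so the keyId of each block should be checked against the lweSecretKeys list.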

If I’m not wrong, the encryption is performed with the secret key. Therefore, here is the data on the secret key based on the different cases that I mentioned above.

  • Conv2d output (client executed all layers except the last 3, encrypted_MB=0.57 -> t_encrypt= approx. 19 sec)
"lweSecretKeys":[
   {
       "id":0,
        "params":{
           "lweDimension":8192,
            "integerPrecision":64,
            "keyType":"binary"
        }
   },
   {
       "id":1,
       "params":{
          "lweDimension":903,
          "integerPrecision":64,
          "keyType":"binary"
       }
   },
   {
       "id":2,
       "params":{
          "lweDimension":593,
          "integerPrecision":64,
          "keyType":"binary"
       }
   },
   {
       "id":3,
       "params":{
          "lweDimension":32768,
          "integerPrecision":64,
          "keyType":"binary"
       }
   },
   {
       "id":4,
       "params":{
          "lweDimension":926,
          "integerPrecision":64,
          "keyType":"binary"
       }
    }
],

Additional info in the circuit['inputs'] section:

"encryption":{
    "keyId":0,
    "variance":4.70197740328915e-38,
    "lweDimension":8192,
    "modulus":{"modulus":{"native":{}}}
},
  • AvgPool output (client executed all layers except the last 6, encrypted_MB=0.57 -> t_encrypt= approx. 5 sec)
"lweSecretKeys":[
   {
       "id":0,
        "params":{
           "lweDimension":2560,
            "integerPrecision":64,
            "keyType":"binary"
        }
   },
   {
       "id":1,
       "params":{
          "lweDimension":611,
          "integerPrecision":64,
          "keyType":"binary"
       }
   },
   {
       "id":2,
       "params":{
          "lweDimension":32768,
          "integerPrecision":64,
          "keyType":"binary"
       }
   },
   {
       "id":3,
       "params":{
          "lweDimension":880,
          "integerPrecision":64,
          "keyType":"binary"
       }
   }
],

Additional info in the circuit['inputs'] section:

"encryption":{
    "keyId":0,
    "variance":4.70197740328915e-38,
    "lweDimension":2560,
    "modulus":{"modulus":{"native":{}}}
},
  • ReLU output (client executed all layers except the last 12, encrypted_MB=1.32 -> t_encrypt= approx. 73 sec)
"lweSecretKeys":[
   {
       "id":0,
        "params":{
           "lweDimension":16384,
            "integerPrecision":64,
            "keyType":"binary"
        }
   },
   {
       "id":1,
       "params":{
          "lweDimension":906,
          "integerPrecision":64,
          "keyType":"binary"
       }
   },
   {
       "id":2,
       "params":{
          "lweDimension":2048,
          "integerPrecision":64,
          "keyType":"binary"
       }
   },
   {
       "id":3,
       "params":{
          "lweDimension":638,
          "integerPrecision":64,
          "keyType":"binary"
       }
   },
   {
       "id":4,
       "params":{
          "lweDimension":32768,
          "integerPrecision":64,
          "keyType":"binary"
       }
    },
   {
       "id":5,
       "params":{
          "lweDimension":873,
          "integerPrecision":64,
          "keyType":"binary"
       }
    },
   {
       "id":6,
       "params":{
          "lweDimension":4096,
          "integerPrecision":64,
          "keyType":"binary"
       }
    },
   {
       "id":7,
       "params":{
          "lweDimension":1041,
          "integerPrecision":64,
          "keyType":"binary"
       }
    }
],

Additional info in the circuit['inputs'] section:

"encryption":{
    "keyId":0,
    "variance":4.70197740328915e-38,
    "lweDimension":16384,
    "modulus":{"modulus":{"native":{}}}
},
  • ReLU output (client executed all layers except the last 14, encrypted_MB=1.32 -> t_encrypt= approx. 9 sec)
"lweSecretKeys":[
   {
       "id":0,
        "params":{
           "lweDimension":2048,
            "integerPrecision":64,
            "keyType":"binary"
        }
   },
   {
       "id":1,
       "params":{
          "lweDimension":611,
          "integerPrecision":64,
          "keyType":"binary"
       }
   },
   {
       "id":2,
       "params":{
          "lweDimension":32768,
          "integerPrecision":64,
          "keyType":"binary"
       }
   },
   {
       "id":3,
       "params":{
          "lweDimension":873,
          "integerPrecision":64,
          "keyType":"binary"
       }
   },
   {
       "id":4,
       "params":{
          "lweDimension":16384,
          "integerPrecision":64,
          "keyType":"binary"
       }
    },
   {
       "id":5,
       "params":{
          "lweDimension":926,
          "integerPrecision":64,
          "keyType":"binary"
       }
    },
   {
       "id":6,
       "params":{
          "lweDimension":4096,
          "integerPrecision":64,
          "keyType":"binary"
       }
    },
   {
       "id":7,
       "params":{
          "lweDimension":1041,
          "integerPrecision":64,
          "keyType":"binary"
       }
    }
],

Additional info in the circuit['inputs'] section:

"encryption":{
    "keyId":0,
    "variance":8.4422531129329586e-31,
    "lweDimension":2048,
    "modulus":{"modulus":{"native":{}}}
},

What are we looking for?
(If you need any more information, I’ll add it)
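One way to compare the four cases is to normalize each measured latency by the tensor size and by the lweDimension of the input encryption key (keyId 0). The sketch below does that arithmetic with the approximate figures posted above. It assumes, and this is not confirmed anywhere in this thread, that the per-sample encryption time scales with the product of those two quantities.

```python
# (label, encrypted_MB, input lweDimension, approx. seconds per sample),
# all copied from the posts above.
cases = [
    ("Conv2d, last 3 layers on server", 0.57, 8192, 19.0),
    ("AvgPool, last 6 layers on server", 0.57, 2560, 5.0),
    ("ReLU, last 12 layers on server", 1.32, 16384, 73.0),
    ("ReLU, last 14 layers on server", 1.32, 2048, 9.0),
]

normalized = {}
for label, megabytes, lwe_dimension, seconds in cases:
    # Seconds per (MB x LWE dimension); similar values across cases would be
    # consistent with the scaling assumption, not a proof of it.
    normalized[label] = seconds / (megabytes * lwe_dimension)
    print(f"{label}: {normalized[label]:.5f}")

spread = max(normalized.values()) / min(normalized.values())
print(f"max/min ratio: {spread:.2f}")
```

With these rounded measurements the four normalized values land within about 25% of each other, which suggests the input lweDimension is worth comparing first when two splits with the same tensor size encrypt at very different speeds.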