Hey,
So I’m having some problems with the Zama Concrete ML linear regression model. I tried to build a minimal failing example here:
from sklearn.linear_model import LinearRegression as SklearnLinearRegression
from concrete.ml.sklearn import LinearRegression as ConcreteLinearRegression
import numpy as np

def generate_dataset(n, t):
    # n training points at X = 1..n, and t test points right after them at X = n+1..n+t.
    x_train = np.array([np.float64(i + 1) for i in range(n)]).reshape(n, 1)
    y_train = np.random.default_rng().uniform(0.0, 100.0, n)
    x_test = np.array([np.float64(i + n + 1) for i in range(t)]).reshape(t, 1)
    return x_train, y_train, x_test
x_train, y_train, x_test = generate_dataset(2, 2)
print("X train: ", " ".join([str(value[0]) for value in x_train]))
print("Y train: ", " ".join([str(value) for value in y_train]))
print("X test: ", " ".join([str(value[0]) for value in x_test]))
plaintext_model = SklearnLinearRegression()
plaintext_model.fit(x_train, y_train)
# Use a generous quantization bit-width so quantization error should be negligible.
concrete_model = ConcreteLinearRegression(n_bits=16)
concrete_model.fit(x_train, y_train)
y_pred_plaintext = plaintext_model.predict(x_test)
# Extend the calibration set well beyond x_test, in case a too-narrow compilation range is the problem.
additional_values = np.array([0., 1., 2., 3., 4., 5., 6., 10., 12.]).reshape(-1, 1)
input_range = np.concatenate((x_test, additional_values), axis=0)
concrete_model.compile(input_range, verbose=True)
y_pred_concrete = concrete_model.predict(x_test, fhe="execute")
print(
    "Plaintext coef and intercept: ",
    plaintext_model.coef_[0],
    plaintext_model.intercept_,
)
print(
    "Concrete coef and intercept: ", concrete_model.coef_[0], concrete_model.intercept_
)
print("Plaintext predictions: ", " ".join([str(value) for value in y_pred_plaintext]))
print("Concrete predictions: ", " ".join([str(value[0]) for value in y_pred_concrete]))
So here we generate just two (X, Y) training points, so a linear regression can fit them perfectly at X=1.0 and X=2.0. We then train two linear regression models – a regular sklearn one and the one from Concrete ML – and run inference on X=3.0 and 4.0.
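For reference, here is the exact line the two training points pin down (just a sanity check on my side, not part of the failing example; the names a and b are mine):

# With only two points, the least-squares fit is exact:
# slope a = (y2 - y1) / (x2 - x1), intercept b = y1 - a * x1
a = (y_train[1] - y_train[0]) / (x_train[1, 0] - x_train[0, 0])
b = y_train[0] - a * x_train[0, 0]
print(a, b)  # should match the coef_ and intercept_ printed below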
The model from sklearn works as expected. The one from Concrete ML just repeats the last known Y (both predictions come out as ~60.2679, which is essentially y_train at X=2.0), and I can’t figure out why.
I initially thought the issue was not having enough points in the calibration set when I compile the circuit, but adding more (via additional_values) doesn’t seem to fix it. Different n_bits values have no effect either. Moreover, the issue persists not only with fhe="execute" but also when it is set to "simulate" and even "disable" (a quick check is sketched below).
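Here is roughly how I checked the three modes (a sketch; as I understand it, "disable" runs the quantized model in the clear, so encryption shouldn’t even be involved there):

for mode in ("disable", "simulate", "execute"):
    y_pred = concrete_model.predict(x_test, fhe=mode)
    print(mode, y_pred.ravel())  # same wrong values in all three modes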
I’m out of ideas now and banging my head against the wall. Any help would be greatly appreciated!
Here’s the output of the program above:
X train: 1.0 2.0
Y train: 12.691858014601243 60.26780285759041
X test: 3.0 4.0
Computation Graph
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
%0 = q_X # EncryptedTensor<int16, shape=(1, 1)> ∈ [-32768, 32767]
%1 = [[1]] # ClearTensor<uint1, shape=(1, 1)> ∈ [1, 1]
%2 = matmul(%0, %1) # EncryptedTensor<int16, shape=(1, 1)> ∈ [-32768, 32767]
%3 = sum(%0, axis=1, keepdims=True) # EncryptedTensor<int16, shape=(1, 1)> ∈ [-32768, 32767]
%4 = 0 # ClearScalar<uint1> ∈ [0, 0]
%5 = multiply(%4, %3) # EncryptedTensor<uint1, shape=(1, 1)> ∈ [0, 0]
%6 = subtract(%2, %5) # EncryptedTensor<int16, shape=(1, 1)> ∈ [-32768, 32767]
%7 = [[-146355]] # ClearTensor<int19, shape=(1, 1)> ∈ [-146355, -146355]
%8 = add(%6, %7) # EncryptedTensor<int19, shape=(1, 1)> ∈ [-179123, -113588]
return %8
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Bit-Width Constraints
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
%0:
%0 >= 16
%1:
%1 >= 1
%2:
%2 >= 16
%0 == %1
%1 == %2
%3:
%3 >= 16
%0 == %3
%4:
%4 >= 1
%5:
%5 >= 1
%4 == %3
%3 == %5
%6:
%6 >= 16
%2 == %5
%5 == %6
%7:
%7 >= 19
%8:
%8 >= 19
%6 == %7
%7 == %8
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Bit-Width Assignments
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
%0 = 19
%1 = 19
%2 = 19
%3 = 19
%4 = 19
%5 = 19
%6 = 19
%7 = 19
%8 = 19
max = 19
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Bit-Width Assigned Computation Graph
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
%0 = q_X # EncryptedTensor<int19, shape=(1, 1)> ∈ [-32768, 32767]
%1 = [[1]] # ClearTensor<uint20, shape=(1, 1)> ∈ [1, 1]
%2 = matmul(%0, %1) # EncryptedTensor<int19, shape=(1, 1)> ∈ [-32768, 32767]
%3 = sum(%0, axis=1, keepdims=True) # EncryptedTensor<int19, shape=(1, 1)> ∈ [-32768, 32767]
%4 = 0 # ClearScalar<uint20> ∈ [0, 0]
%5 = multiply(%4, %3) # EncryptedTensor<uint19, shape=(1, 1)> ∈ [0, 0]
%6 = subtract(%2, %5) # EncryptedTensor<int19, shape=(1, 1)> ∈ [-32768, 32767]
%7 = [[-146355]] # ClearTensor<int20, shape=(1, 1)> ∈ [-146355, -146355]
%8 = add(%6, %7) # EncryptedTensor<int19, shape=(1, 1)> ∈ [-179123, -113588]
return %8
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Optimizer
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
### Optimizer display
--- Circuit
19 bits integers
0 manp (maxi log2 norm2)
--- User config
9.094947e-13 error per pbs call
1.000000e+00 error per circuit call
-- Solution correctness
For each pbs call: 1/2147483647, p_error (4.272044e-13)
For the full circuit: 1/2147483647 global_p_error(4.272044e-13)
--- Complexity for the full circuit
1.000000e+00 Millions Operations
-- Circuit Solution
CircuitSolution {
circuit_keys: CircuitKeys {
secret_keys: [
SecretLweKey {
identifier: 0,
polynomial_size: 1,
glwe_dimension: 1009,
description: "big representation",
},
],
keyswitch_keys: [],
bootstrap_keys: [],
conversion_keyswitch_keys: [],
circuit_bootstrap_keys: [],
private_functional_packing_keys: [],
},
instructions_keys: [],
crt_decomposition: [],
complexity: 1009.0,
p_error: 4.2720437586522667e-13,
global_p_error: 4.2720437586522667e-13,
is_feasible: true,
error_msg: "",
}###
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Statistics
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
size_of_secret_keys: 8072
size_of_bootstrap_keys: 0
size_of_keyswitch_keys: 0
size_of_inputs: 8080
size_of_outputs: 8080
p_error: 4.2720437586522667e-13
global_p_error: 4.2720437586522667e-13
complexity: 1009.0
programmable_bootstrap_count: 0
key_switch_count: 0
packing_key_switch_count: 0
clear_addition_count: 1
clear_addition_count_per_parameter: {
LweSecretKeyParam(dimension=1009): 1
}
encrypted_addition_count: 2
encrypted_addition_count_per_parameter: {
LweSecretKeyParam(dimension=1009): 2
}
clear_multiplication_count: 0
encrypted_negation_count: 1
encrypted_negation_count_per_parameter: {
LweSecretKeyParam(dimension=1009): 1
}
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Plaintext coef and intercept: 47.57594484298915 -34.88408682838791
Concrete coef and intercept: 47.57594484298915 -34.88408682838791
Plaintext predictions: 107.84374770057956 155.41969254356871
Concrete predictions: 60.26794520447507 60.26794520447507