Hi @luis. Thank you for the quick response!
The functionality is about varying the learning rate of the Linear Regression model during the training phase, instead of keeping it constant. Unfortunately I cannot share a link because the repo is private, but I can post some code here.
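In plain (non-encrypted) terms, the step I want inside the circuit is the usual SGD update, except that the learning rate is an input of the circuit instead of a compiled-in constant. A minimal numpy sketch of that update, with illustrative names only (not the actual module):

import numpy as np

def sgd_step(x, y, w, b, lr):
    # x: (batch, n_features), y: (batch, n_targets)
    y_pred = x @ w + b
    dz = y_pred - y
    dw = x.T @ dz / x.shape[0]           # mean gradient w.r.t. the weights
    db = dz.mean(axis=0, keepdims=True)  # mean gradient w.r.t. the bias
    return w - lr * dw, b - lr * db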
Forward method of the torch model:
def forward(
    self,
    features: torch.Tensor,
    targets: torch.Tensor,
    weights: torch.Tensor,
    bias: torch.Tensor,
    lr: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # (1, batch_size, n_targets)
    y_pred = features @ weights + bias

    # Compute the gradients
    # (1, batch_size, n_targets)
    dz = y_pred - targets
    # (1, n_features, n_targets)
    dw = (features.transpose(1, 2) @ dz) / dz.size(1)
    # (1, 1, n_targets)
    db = dz.sum(dim=1, keepdim=True) / dz.size(1)

    weights -= lr * dw
    bias -= lr * db

    # (1, n_features, n_targets), (1, 1, n_targets)
    return weights, bias
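For context, the module is meant to be called like this in plain torch (illustrative shapes; trainer stands for the torch.nn.Module that wraps the forward above, so the exact names here are just my own code):

batch_size, n_features, n_targets = 8, 3, 1
features = torch.randn(1, batch_size, n_features)
targets = torch.randn(1, batch_size, n_targets)
weights = torch.zeros(1, n_features, n_targets)
bias = torch.zeros(1, 1, n_targets)
lr = torch.tensor(0.1)

# One plain-torch training step with a non-constant learning rate
weights, bias = trainer(features, targets, weights, bias, lr)

The training compile set is built with the function below: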
def _build_training_compile_sets_with_lr(
    lbls, x_min, x_max, w_range, b_range, lr_range, batch_size, fit_intercept=True
):
    """
    Create the compile set used to calibrate and compile the training circuit.
    """
    xs = [x_min, x_max, np.zeros(x_min.shape)]
    combinations = list(
        itertools.product(lbls, xs, w_range, b_range, lr_range)
    )
    compile_size = len(combinations)
    n_targets = 1
    n_features = x_min.shape[0]

    # Generate input / target / weight / bias / learning-rate values:
    # - x_compile_set
    # - y_compile_set
    # - w_compile_set
    # - b_compile_set
    # - lr_compile_set
    x_compile_set = np.empty((compile_size, batch_size, n_features))
    y_compile_set = np.empty((compile_size, batch_size, n_targets))
    w_compile_set = np.empty((compile_size, n_features, n_targets))
    b_compile_set = np.empty((compile_size, 1, n_targets))
    lr_compile_set = np.empty((compile_size, 1))

    # compile_set: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]
    compile_set = (
        x_compile_set,
        y_compile_set,
        w_compile_set,
        b_compile_set,
        lr_compile_set,
    )

    # Bound values are hard-coded in order to make sure that the circuit never overflows
    for index, (label, x_value, coef_value, bias_value, lr_value) in enumerate(
        combinations
    ):
        compile_set[0][index] = x_value
        compile_set[1][index] = label
        compile_set[2][index] = coef_value
        if not fit_intercept:
            bias_value *= 0.0
        compile_set[3][index] = bias_value
        compile_set[4][index] = lr_value

    return compile_set
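And this is roughly how the compile set is built and then passed to compile_torch_model (the ranges below are placeholders to illustrate the call, not the values I actually use):

batch_size, n_features = 8, 3
lbls = [0.0, 1.0]                    # broadcast over (batch_size, n_targets)
x_min = np.full((n_features,), -1.0)
x_max = np.full((n_features,), 1.0)
w_range = [-1.0, 1.0]
b_range = [-1.0, 1.0]
lr_range = [0.01, 0.1]               # the varying learning rates to calibrate on

compile_set = _build_training_compile_sets_with_lr(
    lbls, x_min, x_max, w_range, b_range, lr_range, batch_size
)

compile_set is then given to compile_torch_model as the torch_inputset, together with the inputs_encryption_status list that appears in the traceback below.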
This is the error I get:
Compiling training circuit ...
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
Cell In[15], line 1
----> 1 model.prepare_training(X_range, w_range, b_range, lr_range)
File ~/Documents/Projects/trustee/PriveXAI/encrypted_ml/linear_regression.py:231, in EncryptedLinearRegression.prepare_training(self, X_range, w_range, b_range, lr_range, key)
228 compile_kw = self.compile_kw
230 # Build and compile the training quantized module
--> 231 quantized_module_training = self.get_training_quantized_module(
232 X_range=X_range,
233 w_range=w_range,
234 b_range=b_range,
235 lr_range=lr_range,
236 training_kw=training_kw,
237 compile_kw=compile_kw,
238 )
240 n_features = training_kw["n_features"]
242 self.weights, self.bias = _init_regression_training_weights(
243 w_range,
244 b_range,
(...)
247 fit_intercept=training_kw["fit_intercept"],
248 )
File ~/Documents/Projects/trustee/PriveXAI/encrypted_ml/linear_regression.py:316, in EncryptedLinearRegression.get_training_quantized_module(self, X_range, w_range, b_range, lr_range, training_kw, compile_kw)
306 inputs_encryption_status = [
307 "encrypted",
308 "encrypted",
(...)
311 "clear",
312 ]
314 start = time.time()
--> 316 quantized_module_training = compile_torch_model(
317 trainer,
318 compile_set,
319 reduce_sum_copy=True,
320 inputs_encryption_status=inputs_encryption_status,
321 **compile_kw,
322 )
324 # quantized_module_training = build_quantized_module(
325 # trainer,
326 # compile_set,
(...)
329 # reduce_sum_copy=True,
330 # )
332 end = time.time()
File ~/miniconda3/envs/zama/lib/python3.10/site-packages/concrete/ml/torch/compile.py:283, in compile_torch_model(torch_model, torch_inputset, import_qat, configuration, artifacts, show_mlir, n_bits, rounding_threshold_bits, p_error, global_p_error, verbose, inputs_encryption_status, reduce_sum_copy)
271 assert_true(
272 isinstance(torch_model, torch.nn.Module),
273 "The compile_torch_model function must be called on a torch.nn.Module",
274 )
276 assert_false(
277 has_any_qnn_layers(torch_model),
278 "The compile_torch_model was called on a torch.nn.Module that contains "
279 "Brevitas quantized layers. These models must be imported "
280 "using compile_brevitas_qat_model instead.",
281 )
--> 283 return _compile_torch_or_onnx_model(
284 torch_model,
285 torch_inputset,
286 import_qat,
287 configuration=configuration,
288 artifacts=artifacts,
289 show_mlir=show_mlir,
290 n_bits=n_bits,
291 rounding_threshold_bits=rounding_threshold_bits,
292 p_error=p_error,
293 global_p_error=global_p_error,
294 verbose=verbose,
295 inputs_encryption_status=inputs_encryption_status,
296 reduce_sum_copy=reduce_sum_copy,
297 )
File ~/miniconda3/envs/zama/lib/python3.10/site-packages/concrete/ml/torch/compile.py:182, in _compile_torch_or_onnx_model(model, torch_inputset, import_qat, configuration, artifacts, show_mlir, n_bits, rounding_threshold_bits, p_error, global_p_error, verbose, inputs_encryption_status, reduce_sum_copy)
177 inputset_as_numpy_tuple = tuple(
178 convert_torch_tensor_or_numpy_array_to_numpy_array(val) for val in to_tuple(torch_inputset)
179 )
181 # Build the quantized module
--> 182 quantized_module = build_quantized_module(
183 model=model,
184 torch_inputset=inputset_as_numpy_tuple,
185 import_qat=import_qat,
186 n_bits=n_bits,
187 rounding_threshold_bits=rounding_threshold_bits,
188 reduce_sum_copy=reduce_sum_copy,
189 )
191 # Check that p_error or global_p_error is not set in both the configuration and in the direct
192 # parameters
193 check_there_is_no_p_error_options_in_configuration(configuration)
File ~/miniconda3/envs/zama/lib/python3.10/site-packages/concrete/ml/torch/compile.py:122, in build_quantized_module(model, torch_inputset, import_qat, n_bits, rounding_threshold_bits, reduce_sum_copy)
116 post_training_quant = post_training(n_bits, numpy_model, rounding_threshold_bits)
118 # Build the quantized module
119 # FIXME: mismatch here. We traced with dummy_input_for_tracing which made some operator
120 # only work over shape of (1, ., .). For example, some reshape have newshape hardcoded based
121 # on the inputset we sent in the NumpyModule.
--> 122 quantized_module = post_training_quant.quantize_module(*inputset_as_numpy_tuple)
123 # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4127
124 if reduce_sum_copy:
File ~/miniconda3/envs/zama/lib/python3.10/site-packages/concrete/ml/quantization/post_training.py:599, in ONNXConverter.quantize_module(self, *calibration_data)
596 # First transform all parameters to their quantized version
597 self._quantize_params()
--> 599 self._quantize_layers(*calibration_data)
601 # Create quantized module from self.quant_layers_dict
602 quantized_module = QuantizedModule(
603 ordered_module_input_names=(
604 graph_input.name for graph_input in self.numpy_model.onnx_model.graph.input
(...)
610 onnx_model=self.numpy_model.onnx_model,
611 )
File ~/miniconda3/envs/zama/lib/python3.10/site-packages/concrete/ml/quantization/post_training.py:491, in ONNXConverter._quantize_layers(self, *input_calibration_data)
484 assert_true(
485 op_type in ONNX_OPS_TO_QUANTIZED_IMPL,
486 f"{op_type} can't be found in {ONNX_OPS_TO_QUANTIZED_IMPL}",
487 )
489 # Note that the output of a quantized op could be a network output
490 # Thus the quantized op outputs are quantized to the network output bit-width
--> 491 quantized_op_instance = quantized_op_class(
492 self.n_bits_model_outputs,
493 node.name,
494 node_integer_inputs,
495 curr_cst_inputs,
496 self._get_input_quant_opts(curr_calibration_data, quantized_op_class),
497 **attributes,
498 )
500 # Determine if this op computes a tensor that is a graph output, i.e., a tensor
501 # that will be decrypted and de-quantized in the clear
502 quantized_op_instance.produces_graph_output = output_name in graph_output_names
File ~/miniconda3/envs/zama/lib/python3.10/site-packages/concrete/ml/quantization/base_quantized_op.py:840, in QuantizedOpUnivariateOfEncrypted.__init__(self, n_bits_output, op_instance_name, int_input_names, constant_inputs, input_quant_opts, **attrs)
826 super().__init__(
827 n_bits_output,
828 op_instance_name,
(...)
832 **attrs,
833 ) # type: ignore
835 # We do not support this type of operation between encrypted tensors, only between:
836 # - encrypted tensors and float constants
837 # - tensors that are produced by a unique integer tensor
838 # If this operation is applied between two constants
839 # it should be optimized out by the constant folding procedure
--> 840 assert_true(
841 self.can_fuse() or (constant_inputs is not None and len(constant_inputs) == 1),
842 "Do not support this type of operation between encrypted tensors",
843 )
File ~/miniconda3/envs/zama/lib/python3.10/site-packages/concrete/ml/common/debugging/custom_assert.py:40, in assert_true(condition, on_error_msg, error_type)
28 def assert_true(
29 condition: bool, on_error_msg: str = "", error_type: Type[Exception] = AssertionError
30 ):
31 """Provide a custom assert to check that the condition is True.
32
33 Args:
(...)
38
39 """
---> 40 _custom_assert(condition, on_error_msg, error_type)
File ~/miniconda3/envs/zama/lib/python3.10/site-packages/concrete/ml/common/debugging/custom_assert.py:25, in _custom_assert(condition, on_error_msg, error_type)
8 """Provide a custom assert which is kept even if the optimized Python mode is used.
9
10 See https://docs.python.org/3/reference/simple_stmts.html#assert for the documentation
(...)
21
22 """
24 if not condition:
---> 25 raise error_type(on_error_msg)
AssertionError: Do not support this type of operation between encrypted tensors