Using scalars as inputs to circuits from compiled torch models

Hi, first of all, congrats on your great work.

Is it possible to have a scalar as an input to a circuit compiled from a torch model, to support a simple scalar-vector multiplication? If so, what would be the recommended way to do it?

I have tried several things. I added a scalar input to my torch model's .forward method, set up my compile sets, and ran the circuit. I have tried combinations like scalar-vector and vector-vector element-wise multiplication, but I get an AssertionError about an unsupported operation.

The only thing that works is passing a diagonal matrix with the scalar on the diagonal and performing a matrix-vector multiplication, but this feels odd and like overkill, and it also produces strange, incorrect results.

Some context on what I am working on.

I am experimenting with training simple models in the encrypted domain following your example of the logistic regressor. I am working on a linear regression model.

Since I was not interested in interoperability with sklearn and was mostly interested in the training part, I created a class EncryptedLinearRegression stripped of sklearn-related attributes and mechanics.

The class performs steps similar to those in your logistic regression training example:

  • Building compile sets for training
  • Initializing a torch model
  • Compiling the torch model
  • Training on fixed-size batches through the QuantizedModule.forward method
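The last step presumably needs every batch to have exactly `batch_size` rows, since an FHE circuit is compiled for fixed input shapes. A minimal NumPy sketch of such batching, assuming leftover rows that do not fill a full batch are simply dropped (the helper `iter_fixed_batches` is hypothetical, not part of Concrete ML):

```python
import numpy as np


def iter_fixed_batches(x, y, batch_size, rng=None):
    """Yield (x_batch, y_batch) pairs that all have exactly batch_size rows.

    Hypothetical helper: rows that do not fill a complete batch are dropped,
    so every batch has the same shape as the compile set.
    """
    indices = np.arange(x.shape[0])
    if rng is not None:
        rng.shuffle(indices)  # optional shuffling, e.g. once per epoch
    n_full = x.shape[0] // batch_size
    for i in range(n_full):
        idx = indices[i * batch_size:(i + 1) * batch_size]
        # Add a leading axis of size 1 to match the (1, batch_size, ...)
        # layout used in the shape comments of the forward method.
        yield x[idx][np.newaxis, ...], y[idx][np.newaxis, ...]


x = np.random.default_rng(0).normal(size=(10, 3))
y = np.random.default_rng(1).normal(size=(10, 1))
batches = list(iter_fixed_batches(x, y, batch_size=4))
print(len(batches), batches[0][0].shape, batches[0][1].shape)
```

With 10 rows and `batch_size=4`, two full batches are produced and the last 2 rows are skipped.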

Hello @Akis, thanks for your kind words.

You should be able to pass a scalar to a circuit, but there may be a bug, since we haven't encountered a model that needs this yet. You might also want to try an array of shape (1,).
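As a hypothetical illustration of the shape-(1,) suggestion: NumPy (and PyTorch, which follows the same broadcasting rules) lets a length-1 array multiply a tensor of any shape, so the scalar can be carried as an array of shape (1,). The gradient shape below follows the forward method posted later in the thread; this only checks shapes in plain NumPy, not what Concrete ML accepts.

```python
import numpy as np

n_features, n_targets = 3, 1

lr_scalar = 0.1
# Wrap the scalar into an array of shape (1,) instead of a 0-d value.
lr = np.asarray(lr_scalar, dtype=np.float64).reshape(1)

dw = np.ones((1, n_features, n_targets))  # gradient layout from the forward method

update = lr * dw  # broadcasting: (1,) * (1, n_features, n_targets)
print(lr.shape, update.shape)
```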

If you can provide us a link to your implementation it would help us help you :slightly_smiling_face:

Indeed, you should be able to build FHE training for a linear regressor by modifying the current implementation of the SGDClassifier training. Note that you need some prior information on the bounds of the target y, since it will need to be quantized too.
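To see why the bounds of y matter, here is a generic uniform quantizer (an illustration only, not Concrete ML's actual implementation). The scale is fixed from a range `[y_min, y_max]` assumed to be known in advance, so any target outside that range is clipped and silently distorted:

```python
import numpy as np


def uniform_quantize(values, v_min, v_max, n_bits):
    """Map floats in [v_min, v_max] to integers in [0, 2**n_bits - 1].

    Generic asymmetric uniform quantization, for illustration only.
    Values outside the calibrated range are clipped.
    """
    n_levels = 2**n_bits - 1
    scale = (v_max - v_min) / n_levels
    q = np.round((values - v_min) / scale)
    return np.clip(q, 0, n_levels).astype(np.int64), scale


def dequantize(q, v_min, scale):
    """Map integers back to floats on the same grid."""
    return q * scale + v_min


# Bounds assumed to be known a priori for the regression target
y_min, y_max, n_bits = -2.0, 2.0, 4
y = np.array([-2.0, 0.0, 1.5, 3.0])  # 3.0 lies outside the assumed range

q, scale = uniform_quantize(y, y_min, y_max, n_bits)
y_back = dequantize(q, y_min, scale)
print(q)
print(y_back)  # the last target comes back as ~2.0 instead of 3.0
```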

Hi @luis. Thank you for your immediate response!

The functionality I am after is varying the learning rate of the linear regression model during training instead of using a constant one. Unfortunately, I cannot share a link because the repo is private, but I can post some code here.
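For reference, a varying learning rate usually means a schedule such as step decay. A hypothetical sketch (the parameters `lr0`, `decay` and `every` are illustrative, not from the original code); every value it produces would need to lie within the `lr_range` used to build the compile set:

```python
def step_decay(epoch, lr0=0.5, decay=0.5, every=2):
    """Hypothetical step-decay schedule: multiply lr0 by `decay` every `every` epochs."""
    return lr0 * decay ** (epoch // every)


schedule = [step_decay(e) for e in range(6)]
print(schedule)  # [0.5, 0.5, 0.25, 0.25, 0.125, 0.125]

# Extreme learning-rate values that the compile set would have to cover
lr_range = (min(schedule), max(schedule))
```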

Forward method of the torch model:

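from typing import Tuple

import torch


    # Excerpt: forward method of the trainer torch.nn.Module
    def forward(
        self,
        features: torch.Tensor,
        targets: torch.Tensor,
        weights: torch.Tensor,
        bias: torch.Tensor,
        lr: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # (1, batch_size, n_targets)
        y_pred = features @ weights + bias

        # Compute the gradients
        # (1, batch_size, n_targets)
        dz = y_pred - targets

        # (1, n_features, n_targets)
        dw = (features.transpose(1, 2) @ dz) / dz.size(1)

        # (1, 1, n_targets)
        db = (dz.sum(dim=1, keepdim=True)) / dz.size(1)

        # lr is a circuit input here, so lr * dw and lr * db are element-wise
        # products between two non-constant tensors.
        # Out-of-place updates avoid mutating the input tensors.
        weights = weights - lr * dw
        bias = bias - lr * db

        # (1, n_features, n_targets), (1, 1, n_targets)
        return weights, bias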

import itertools

import numpy as np


def _build_training_compile_sets_with_lr(
    lbls, x_min, x_max, w_range, b_range, lr_range, batch_size, fit_intercept=True
):
    """Create compile set."""
    xs = [x_min, x_max, np.zeros(x_min.shape)]

    combinations = list(
        itertools.product(lbls, xs, w_range, b_range, lr_range)
    )

    compile_size = len(combinations)

    n_targets = 1
    n_features = x_min.shape[0]

    # Generate input / target / weight / bias / learning-rate values:
    # - x_compile_set
    # - y_compile_set
    # - w_compile_set
    # - b_compile_set
    # - lr_compile_set
    x_compile_set = np.empty((compile_size, batch_size, n_features))
    y_compile_set = np.empty((compile_size, batch_size, n_targets))
    w_compile_set = np.empty((compile_size, n_features, n_targets))
    b_compile_set = np.empty((compile_size, 1, n_targets))
    lr_compile_set = np.empty((compile_size, 1))

    # compile_set : Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]
    compile_set = (
        x_compile_set,
        y_compile_set,
        w_compile_set,
        b_compile_set,
        lr_compile_set,
    )

    # Bound values are hard-coded in order to make sure that the circuit never overflows
    for index, (label, x_value, coef_value, bias_value, lr_value) in enumerate(
        combinations
    ):
        compile_set[0][index] = x_value
        compile_set[1][index] = label
        compile_set[2][index] = coef_value
        if not fit_intercept:
            # Out-of-place so that the caller's b_range entries are not modified
            bias_value = bias_value * 0.0
        compile_set[3][index] = bias_value
        compile_set[4][index] = lr_value

    return compile_set
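To judge whether the compiled circuit's results are correct (for example the strange results seen with the diagonal-matrix workaround), it can help to compare them against a plain floating-point version of the same update. A minimal NumPy sketch mirroring the forward method above (the name `reference_step` is hypothetical), run here on noise-free synthetic data:

```python
import numpy as np


def reference_step(features, targets, weights, bias, lr):
    """Float reference of one gradient step for linear regression.

    Shapes follow the forward method: features (1, batch, n_features),
    targets (1, batch, n_targets), weights (1, n_features, n_targets),
    bias (1, 1, n_targets), lr a Python float.
    """
    batch_size = features.shape[1]
    y_pred = features @ weights + bias
    dz = y_pred - targets
    dw = np.transpose(features, (0, 2, 1)) @ dz / batch_size
    db = dz.sum(axis=1, keepdims=True) / batch_size
    return weights - lr * dw, bias - lr * db


rng = np.random.default_rng(0)
true_w = np.array([[[2.0], [-1.0]]])  # (1, n_features=2, n_targets=1)
x = rng.normal(size=(1, 64, 2))
y = x @ true_w + 0.5  # true bias 0.5, no noise

w = np.zeros((1, 2, 1))
b = np.zeros((1, 1, 1))
for _ in range(500):
    w, b = reference_step(x, y, w, b, lr=0.1)
print(w.ravel(), b.ravel())  # close to [2, -1] and [0.5]
```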

This is the error I get:

Compiling training circuit ...
AssertionError                            Traceback (most recent call last)
Cell In[15], line 1
----> 1 model.prepare_training(X_range, w_range, b_range, lr_range)

File ~/Documents/Projects/trustee/PriveXAI/encrypted_ml/, in EncryptedLinearRegression.prepare_training(self, X_range, w_range, b_range, lr_range, key)
    228 compile_kw = self.compile_kw
    230 # Build and compile the training quantized module
--> 231 quantized_module_training = self.get_training_quantized_module(
    232     X_range=X_range,
    233     w_range=w_range,
    234     b_range=b_range,
    235     lr_range=lr_range,
    236     training_kw=training_kw,
    237     compile_kw=compile_kw,
    238 )
    240 n_features = training_kw["n_features"]
    242 self.weights, self.bias = _init_regression_training_weights(
    243     w_range,
    244     b_range,
    247     fit_intercept=training_kw["fit_intercept"],
    248 )

File ~/Documents/Projects/trustee/PriveXAI/encrypted_ml/, in EncryptedLinearRegression.get_training_quantized_module(self, X_range, w_range, b_range, lr_range, training_kw, compile_kw)
    306 inputs_encryption_status = [
    307     "encrypted",
    308     "encrypted",
    311     "clear",
    312 ]
    314 start = time.time()
--> 316 quantized_module_training = compile_torch_model(
    317     trainer,
    318     compile_set,
    319     reduce_sum_copy=True,
    320     inputs_encryption_status=inputs_encryption_status,
    321     **compile_kw,
    322 )
    324 # quantized_module_training = build_quantized_module(
    325 #     trainer,
    326 #     compile_set,
    329 #     reduce_sum_copy=True,
    330 # )
    332 end = time.time()

File ~/miniconda3/envs/zama/lib/python3.10/site-packages/concrete/ml/torch/, in compile_torch_model(torch_model, torch_inputset, import_qat, configuration, artifacts, show_mlir, n_bits, rounding_threshold_bits, p_error, global_p_error, verbose, inputs_encryption_status, reduce_sum_copy)
    271 assert_true(
    272     isinstance(torch_model, torch.nn.Module),
    273     "The compile_torch_model function must be called on a torch.nn.Module",
    274 )
    276 assert_false(
    277     has_any_qnn_layers(torch_model),
    278     "The compile_torch_model was called on a torch.nn.Module that contains "
    279     "Brevitas quantized layers. These models must be imported "
    280     "using compile_brevitas_qat_model instead.",
    281 )
--> 283 return _compile_torch_or_onnx_model(
    284     torch_model,
    285     torch_inputset,
    286     import_qat,
    287     configuration=configuration,
    288     artifacts=artifacts,
    289     show_mlir=show_mlir,
    290     n_bits=n_bits,
    291     rounding_threshold_bits=rounding_threshold_bits,
    292     p_error=p_error,
    293     global_p_error=global_p_error,
    294     verbose=verbose,
    295     inputs_encryption_status=inputs_encryption_status,
    296     reduce_sum_copy=reduce_sum_copy,
    297 )

File ~/miniconda3/envs/zama/lib/python3.10/site-packages/concrete/ml/torch/, in _compile_torch_or_onnx_model(model, torch_inputset, import_qat, configuration, artifacts, show_mlir, n_bits, rounding_threshold_bits, p_error, global_p_error, verbose, inputs_encryption_status, reduce_sum_copy)
    177 inputset_as_numpy_tuple = tuple(
    178     convert_torch_tensor_or_numpy_array_to_numpy_array(val) for val in to_tuple(torch_inputset)
    179 )
    181 # Build the quantized module
--> 182 quantized_module = build_quantized_module(
    183     model=model,
    184     torch_inputset=inputset_as_numpy_tuple,
    185     import_qat=import_qat,
    186     n_bits=n_bits,
    187     rounding_threshold_bits=rounding_threshold_bits,
    188     reduce_sum_copy=reduce_sum_copy,
    189 )
    191 # Check that p_error or global_p_error is not set in both the configuration and in the direct
    192 # parameters
    193 check_there_is_no_p_error_options_in_configuration(configuration)

File ~/miniconda3/envs/zama/lib/python3.10/site-packages/concrete/ml/torch/, in build_quantized_module(model, torch_inputset, import_qat, n_bits, rounding_threshold_bits, reduce_sum_copy)
    116 post_training_quant = post_training(n_bits, numpy_model, rounding_threshold_bits)
    118 # Build the quantized module
    119 # FIXME: mismatch here. We traced with dummy_input_for_tracing which made some operator
    120 # only work over shape of (1, ., .). For example, some reshape have newshape hardcoded based
    121 # on the inputset we sent in the NumpyModule.
--> 122 quantized_module = post_training_quant.quantize_module(*inputset_as_numpy_tuple)
    123 # FIXME:
    124 if reduce_sum_copy:

File ~/miniconda3/envs/zama/lib/python3.10/site-packages/concrete/ml/quantization/, in ONNXConverter.quantize_module(self, *calibration_data)
    596 # First transform all parameters to their quantized version
    597 self._quantize_params()
--> 599 self._quantize_layers(*calibration_data)
    601 # Create quantized module from self.quant_layers_dict
    602 quantized_module = QuantizedModule(
    603     ordered_module_input_names=(
    604 for graph_input in self.numpy_model.onnx_model.graph.input
    610     onnx_model=self.numpy_model.onnx_model,
    611 )

File ~/miniconda3/envs/zama/lib/python3.10/site-packages/concrete/ml/quantization/, in ONNXConverter._quantize_layers(self, *input_calibration_data)
    484 assert_true(
    485     op_type in ONNX_OPS_TO_QUANTIZED_IMPL,
    486     f"{op_type} can't be found in {ONNX_OPS_TO_QUANTIZED_IMPL}",
    487 )
    489 # Note that the output of a quantized op could be a network output
    490 # Thus the quantized op outputs are quantized to the network output bit-width
--> 491 quantized_op_instance = quantized_op_class(
    492     self.n_bits_model_outputs,
    494     node_integer_inputs,
    495     curr_cst_inputs,
    496     self._get_input_quant_opts(curr_calibration_data, quantized_op_class),
    497     **attributes,
    498 )
    500 # Determine if this op computes a tensor that is a graph output, i.e., a tensor
    501 # that will be decrypted and de-quantized in the clear
    502 quantized_op_instance.produces_graph_output = output_name in graph_output_names

File ~/miniconda3/envs/zama/lib/python3.10/site-packages/concrete/ml/quantization/, in QuantizedOpUnivariateOfEncrypted.__init__(self, n_bits_output, op_instance_name, int_input_names, constant_inputs, input_quant_opts, **attrs)
    826 super().__init__(
    827     n_bits_output,
    828     op_instance_name,
    832     **attrs,
    833 )  # type: ignore
    835 # We do not support this type of operation between encrypted tensors, only between:
    836 # - encrypted tensors and float constants
    837 # - tensors that are produced by a unique integer tensor
    838 # If this operation is applied between two constants
    839 # it should be optimized out by the constant folding procedure
--> 840 assert_true(
    841     self.can_fuse() or (constant_inputs is not None and len(constant_inputs) == 1),
    842     "Do not support this type of operation between encrypted tensors",
    843 )

File ~/miniconda3/envs/zama/lib/python3.10/site-packages/concrete/ml/common/debugging/, in assert_true(condition, on_error_msg, error_type)
     28 def assert_true(
     29     condition: bool, on_error_msg: str = "", error_type: Type[Exception] = AssertionError
     30 ):
     31     """Provide a custom assert to check that the condition is True.
     33     Args:
     39     """
---> 40     _custom_assert(condition, on_error_msg, error_type)

File ~/miniconda3/envs/zama/lib/python3.10/site-packages/concrete/ml/common/debugging/, in _custom_assert(condition, on_error_msg, error_type)
      8 """Provide a custom assert which is kept even if the optimized Python mode is used.
     10 See for the documentation
     22 """
     24 if not condition:
---> 25     raise error_type(on_error_msg)

AssertionError: Do not support this type of operation between encrypted tensors

My pleasure.

Interesting use-case!

We actually did experiment with different learning rates for the binary classifier training in FHE, but we saw that learning rates other than 1.0 did not behave very well with quantization.
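A rough, hypothetical illustration of one way a small learning rate can interact badly with quantization (a generic effect, not the experiment described above): if the weights live on a coarse grid, an update smaller than half a quantization step rounds back to the same value and is lost.

```python
import numpy as np

n_bits = 4
w_min, w_max = -1.0, 1.0
step = (w_max - w_min) / (2**n_bits - 1)  # size of one quantization step


def snap(w):
    """Round a float weight onto the uniform n_bits grid over [w_min, w_max]."""
    return np.round((w - w_min) / step) * step + w_min


w = snap(0.2)
grad = 0.3
# Does the quantized weight change after one update, for each learning rate?
changed = {lr: bool(snap(w - lr * grad) != w) for lr in (1.0, 0.1)}
print(changed)  # {1.0: True, 0.1: False}
```

With 4 bits over [-1, 1] the step is about 0.133, so an update of 0.03 (lr = 0.1) is rounded away while 0.3 (lr = 1.0) survives.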

Also, what you are trying to do here is pass the learning rate as an encrypted input, and I'm not sure that is what you really want. You could instead define the function programmatically so that it uses an external variable containing the learning rate, which avoids making it an input to your circuit. That said, this approach would require one circuit per learning rate, and the learning rate would not be encrypted.
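The external-variable suggestion can be sketched as a factory that bakes the learning rate into the torch module as a plain Python float, so it becomes a constant rather than a circuit input. This is a hypothetical sketch: `make_trainer` and `LinearRegressionTrainer` are illustrative names, and the compilation step is only indicated in a comment. It implies one compiled circuit per learning-rate value.

```python
from typing import Tuple

import torch


class LinearRegressionTrainer(torch.nn.Module):
    """One gradient step for linear regression with a fixed, clear learning rate."""

    def __init__(self, lr: float):
        super().__init__()
        self.lr = float(lr)  # plain Python float: a constant, not a module input

    def forward(
        self,
        features: torch.Tensor,
        targets: torch.Tensor,
        weights: torch.Tensor,
        bias: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        y_pred = features @ weights + bias
        dz = y_pred - targets
        dw = (features.transpose(1, 2) @ dz) / dz.size(1)
        db = dz.sum(dim=1, keepdim=True) / dz.size(1)
        # Tensor times float constant: the case the assertion's comment lists
        # as supported ("encrypted tensors and float constants").
        return weights - self.lr * dw, bias - self.lr * db


def make_trainer(lr: float) -> LinearRegressionTrainer:
    return LinearRegressionTrainer(lr)


# One module (hence one circuit) per learning rate, e.g.
# trainers = {lr: make_trainer(lr) for lr in (0.5, 0.25)},
# each of which would be compiled separately with compile_torch_model.
trainer = make_trainer(0.5)
x = torch.ones(1, 4, 2)
y = torch.zeros(1, 4, 1)
w = torch.ones(1, 2, 1)
b = torch.zeros(1, 1, 1)
new_w, new_b = trainer(x, y, w, b)
print(new_w.flatten(), new_b.flatten())
```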

The error you are seeing occurs because, although we do support encrypted/encrypted matrix multiplication, we do not currently support encrypted/encrypted element-wise multiplication.
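Since matrix multiplication between encrypted values is supported, one conceivable workaround is to express the scalar product as a matmul. With n_targets = 1, multiplying dw of shape (1, n_features, 1) by the learning rate equals a matmul with the learning rate stored as a (1, 1, 1) matrix, a much smaller version of the diagonal-matrix trick from the original question. This NumPy sketch only checks the algebra; whether it compiles and quantizes accurately in Concrete ML is an untested assumption.

```python
import numpy as np

n_features = 3
rng = np.random.default_rng(0)
dw = rng.normal(size=(1, n_features, 1))  # (1, n_features, n_targets), n_targets = 1
db = rng.normal(size=(1, 1, 1))           # (1, 1, n_targets)
lr = 0.1

lr_matrix = np.full((1, 1, 1), lr)  # learning rate as a 1x1 matrix

# Element-wise scalar products vs. the equivalent matrix products
assert np.allclose(lr * dw, dw @ lr_matrix)
assert np.allclose(lr * db, db @ lr_matrix)

# For n_targets > 1, the diagonal matrix lr * I (n_targets x n_targets) plays the same role.
n_targets = 2
dw2 = rng.normal(size=(1, n_features, n_targets))
lr_diag = lr * np.eye(n_targets)[np.newaxis]  # (1, n_targets, n_targets)
print(np.allclose(lr * dw2, dw2 @ lr_diag))  # True
```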

I’m adding this to our roadmap but I can’t give an ETA on the availability of this feature.

Thanks Luis!
I see what you mean. For my part, at least for the current use case, some instability is OK, as the results I get for different learning rates and batch sizes are consistent with the theory, and I am mostly trying to identify trade-offs between different configurations.

I do not need the lr variable to be encrypted per se, since it is just an input during training. I also experimented with declaring the learning rate as a clear input through the inputs_encryption_status argument of compile_torch_model, but it had no effect on the error above. It was also not clear to me whether I had to declare anything differently when building the compile sets. Do you have any advice on following this kind of approach?

We actually have a known issue: inputs_encryption_status is not propagated properly through Concrete ML (CML), so it is not taken into account yet. I'll create an issue on our side to add a comment about it for the time being and fix it in the future.

For training, the compile set should be defined to cover the bounds of each operation in the circuit, so that the automatic quantization calibration is done properly and overflows are avoided when running in FHE.
You can check what we did for classification and just change the extreme values of the target (1.0 and 0.0 in the case of classification) to match your own.
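Following that advice, the classification labels 0.0 and 1.0 would be replaced by the extreme values of the regression target. A minimal sketch, assuming a prior range `[y_min, y_max]` and the `lbls` argument of `_build_training_compile_sets_with_lr` above (all numbers are illustrative):

```python
import itertools

import numpy as np

batch_size, n_targets, n_features = 8, 1, 4
y_min, y_max = -3.0, 5.0  # assumed prior bounds on the target

# Each label entry is a full (batch_size, n_targets) batch at one extreme.
lbls = [
    np.full((batch_size, n_targets), y_min),
    np.full((batch_size, n_targets), y_max),
]

# Illustrative extremes for the other inputs
xs = [np.full(n_features, -1.0), np.full(n_features, 1.0), np.zeros(n_features)]
w_range = [-1.0, 1.0]
b_range = [-1.0, 1.0]
lr_range = [0.1, 1.0]

# 2 labels * 3 feature settings * 2 weights * 2 biases * 2 learning rates
combinations = list(itertools.product(lbls, xs, w_range, b_range, lr_range))
print(len(combinations))  # 48
```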
