How to change the number of active neurons during the pruning process?

I saw this graph in the paper, along with this code:

import numpy as np
import torch
import torch.nn.utils.prune as prune
from torch import nn

import brevitas.nn as qnn


class TinyCNN(nn.Module):
    """A very small CNN to classify the sklearn digits dataset.

    This class also allows pruning to a maximum of 12 active neurons, which
    should help keep the accumulator bit width low.
    """

    def __init__(self, n_classes, n_bits) -> None:
        """Construct the CNN with a configurable number of classes."""
        super().__init__()

        a_bits = n_bits
        w_bits = n_bits

        # This network has a total complexity of 1216 MAC
        self.q1 = qnn.QuantIdentity(bit_width=a_bits, return_quant_tensor=True)
        self.conv1 = qnn.QuantConv2d(1, 8, 3, stride=1, padding=0, weight_bit_width=w_bits)
        self.q2 = qnn.QuantIdentity(bit_width=a_bits, return_quant_tensor=True)
        self.conv2 = qnn.QuantConv2d(8, 16, 3, stride=2, padding=0, weight_bit_width=w_bits)
        self.q3 = qnn.QuantIdentity(bit_width=a_bits, return_quant_tensor=True)
        self.conv3 = qnn.QuantConv2d(16, 32, 2, stride=1, padding=0, weight_bit_width=w_bits)
        self.q4 = qnn.QuantIdentity(bit_width=a_bits, return_quant_tensor=True)
        # After the conv stack, an 8x8 digits image is reduced to 32 features
        self.fc1 = qnn.QuantLinear(32, n_classes, weight_bit_width=w_bits, bias=True)

        # Enable pruning, prepared for training
        self.toggle_pruning(True)

    def toggle_pruning(self, enable):
        """Enables or removes pruning."""

        # Maximum number of active neurons (i.e. corresponding weight != 0)
        n_active = 12

        # Go through all the convolution layers
        for layer in (self.conv1, self.conv2, self.conv3):
            s = layer.weight.shape

            # Compute fan-in (number of inputs to a neuron)
            # and fan-out (number of neurons in the layer)
            st = [s[0], np.prod(s[1:])]

            # The number of input neurons (fan-in) is the product of
            # the kernel width x height x inChannels.
            if st[1] > n_active:
                if enable:
                    # This will create a forward hook that multiplies a mask tensor
                    # (containing 0s and 1s) with the weights during forward
                    prune.l1_unstructured(layer, "weight", (st[1] - n_active) * st[0])
                else:
                    # When disabling pruning, the mask is multiplied with the weights
                    # and the result is stored in the weights member
                    prune.remove(layer, "weight")

    def forward(self, x):
        """Run inference on the tiny CNN, apply the decision layer on the reshaped conv output."""

        x = self.q1(x)
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.q2(x)
        x = self.conv2(x)
        x = torch.relu(x)
        x = self.q3(x)
        x = self.conv3(x)
        x = torch.relu(x)
        x = self.q4(x)
        x = x.flatten(1)
        x = self.fc1(x)
        return x

and I want to know how to change the number of active neurons during the pruning process. If I want to change the number of active neurons to 110, how should I modify the code?

The graph was produced using the built-in NN, which is a fully connected MLP model, while your code shows a convolutional model.

If you want to reproduce the graph, you should use the NeuralNetworkClassifier class, following the docs here: Neural Networks - Concrete ML. Note the n_accum_bits parameter, which you can vary to configure the number of active neurons for you.
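As a minimal sketch of how that could look: the skorch-style module__ parameter names below follow the Concrete ML docs but may differ between versions, and X_train/y_train are placeholders for your own data.

from torch import nn
from concrete.ml.sklearn import NeuralNetworkClassifier

params = {
    "module__n_layers": 3,
    "module__n_w_bits": 3,
    "module__n_a_bits": 3,
    # Lowering n_accum_bits forces a sparser network (fewer active neurons),
    # raising it lets more neurons stay active
    "module__n_accum_bits": 8,
    "module__activation_function": nn.ReLU,
    "max_epochs": 100,
}

model = NeuralNetworkClassifier(**params)
model.fit(X_train, y_train)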

If you want to change the number of active neurons in your code, simply change the n_active variable in your pruning function; see the sketch below. Make sure a layer's fan-in (number of input weights per neuron) is greater than the number of active ones, otherwise pruning is skipped for that layer.
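For instance, you could expose n_active as an argument of the method instead of hard-coding it. One caveat that follows from your model: the conv fan-ins are 1*3*3 = 9, 8*3*3 = 72 and 16*2*2 = 64, so a budget of 110 exceeds all of them and no layer would actually be pruned.

    def toggle_pruning(self, enable, n_active=110):
        """Enables or removes pruning with a configurable number of active weights."""
        for layer in (self.conv1, self.conv2, self.conv3):
            s = layer.weight.shape
            st = [s[0], np.prod(s[1:])]
            # A layer is only pruned if its fan-in exceeds the active budget
            if st[1] > n_active:
                if enable:
                    prune.l1_unstructured(layer, "weight", (st[1] - n_active) * st[0])
                else:
                    prune.remove(layer, "weight")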

Thank you. If I want to prune my own model, like the convolutional model code I showed above, how should I modify my code?

The pruning function in that code is pretty generic. You could copy that function to your new model and change the for layer in (self.conv1, self.conv2, self.conv3): line to iterate over your own layers.
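For instance, here is a sketch that walks over all conv/linear modules rather than a hard-coded tuple. It assumes the brevitas quantized layers subclass the corresponding torch modules (which is the case for QuantConv2d and QuantLinear); note that, unlike the original, it would also prune any linear layers whose fan-in exceeds the budget.

    def toggle_pruning(self, enable, n_active=12):
        """Prune every convolution/linear layer found in the model."""
        for layer in self.modules():
            # Skip anything that is not a conv/linear layer (this check covers
            # both plain torch layers and their brevitas quantized subclasses)
            if not isinstance(layer, (nn.Conv2d, nn.Linear)):
                continue
            s = layer.weight.shape
            st = [s[0], np.prod(s[1:])]
            if st[1] > n_active:
                if enable:
                    prune.l1_unstructured(layer, "weight", (st[1] - n_active) * st[0])
                else:
                    prune.remove(layer, "weight")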