Hello, how can I use L1-norm unstructured pruning on CIFAR-10, as described in the paper “Deep Neural Networks for Encrypted Inference with TFHE”?
I added a `toggle_pruning` function to the original network, but with pruning enabled, compilation fails with 'RuntimeError: NoParametersFound',
and running in FHE fails with 'ValueError: vector::_M_default_append'.
My model code, including the `toggle_pruning` function, is below, followed by a sketch of how I call it. Alternatively, could you provide the complete code for the paper “Deep Neural Networks for Encrypted Inference with TFHE”?
import torch
from torch.nn import AvgPool2d, BatchNorm1d, BatchNorm2d, Module, ModuleList
from torch.nn.utils import prune

from brevitas.core.restrict_val import RestrictValueType
from brevitas.nn import QuantConv2d, QuantIdentity, QuantLinear

from .common import CommonActQuant, CommonWeightQuant
from .tensor_norm import TensorNorm

CNV_OUT_CH_POOL = [(64, False), (64, True), (128, False), (128, True), (256, False), (256, False)]
INTERMEDIATE_FC_FEATURES = [(256, 512), (512, 512)]
LAST_FC_IN_FEATURES = 512
LAST_FC_PER_OUT_CH_SCALING = False
POOL_SIZE = 2
KERNEL_SIZE = 3

class CNV(Module):
    def __init__(self, num_classes, weight_bit_width, act_bit_width, in_bit_width, in_ch):
        super().__init__()
        self.conv_features = ModuleList()
        self.linear_features = ModuleList()

        # Input quantizer for the Q1.7 input format
        self.conv_features.append(
            QuantIdentity(
                act_quant=CommonActQuant,
                return_quant_tensor=True,
                bit_width=in_bit_width,
                min_val=-1.0,
                max_val=1.0 - 2.0 ** (-7),
                narrow_range=False,
                restrict_scaling_type=RestrictValueType.POWER_OF_TWO,
            )
        )
        for out_ch, is_pool_enabled in CNV_OUT_CH_POOL:
            self.conv_features.append(
                QuantConv2d(
                    kernel_size=KERNEL_SIZE,
                    in_channels=in_ch,
                    out_channels=out_ch,
                    bias=True,
                    weight_quant=CommonWeightQuant,
                    weight_bit_width=weight_bit_width,
                )
            )
            in_ch = out_ch
            self.conv_features.append(BatchNorm2d(in_ch, eps=1e-4))
            self.conv_features.append(
                QuantIdentity(act_quant=CommonActQuant, bit_width=act_bit_width)
            )
            if is_pool_enabled:
                self.conv_features.append(AvgPool2d(kernel_size=POOL_SIZE))
                self.conv_features.append(
                    QuantIdentity(act_quant=CommonActQuant, bit_width=act_bit_width)
                )
        for in_features, out_features in INTERMEDIATE_FC_FEATURES:
            self.linear_features.append(
                QuantLinear(
                    in_features=in_features,
                    out_features=out_features,
                    bias=False,
                    weight_quant=CommonWeightQuant,
                    weight_bit_width=weight_bit_width,
                )
            )
            self.linear_features.append(BatchNorm1d(out_features, eps=1e-4))
            self.linear_features.append(
                QuantIdentity(act_quant=CommonActQuant, bit_width=act_bit_width)
            )

        self.linear_features.append(
            QuantLinear(
                in_features=LAST_FC_IN_FEATURES,
                out_features=num_classes,
                bias=False,
                weight_quant=CommonWeightQuant,
                weight_bit_width=weight_bit_width,
            )
        )
        self.linear_features.append(TensorNorm())
        # self.toggle_pruning(True)

        for m in self.modules():
            if isinstance(m, (QuantConv2d, QuantLinear)):
                torch.nn.init.uniform_(m.weight.data, -1, 1)

    def clip_weights(self, min_val, max_val):
        for mod in self.conv_features:
            if isinstance(mod, QuantConv2d):
                mod.weight.data.clamp_(min_val, max_val)
        for mod in self.linear_features:
            if isinstance(mod, QuantLinear):
                mod.weight.data.clamp_(min_val, max_val)

    def toggle_pruning(self, enable):
        """Enables or removes pruning."""
        # Maximum number of active (i.e. non-zero) weights per output neuron
        n_active = 110

        # Linear layer in_features: 256 -> 512 -> 512 -> num_classes
        for mod in self.linear_features:
            if isinstance(mod, QuantLinear):
                if enable:
                    # Prune the layer so that, on average, each of the
                    # weight.shape[0] output neurons keeps n_active of its
                    # weight.shape[1] input weights (l1_unstructured ranks
                    # weights within the whole layer, not per neuron).
                    prune.l1_unstructured(
                        mod, "weight", (mod.weight.shape[1] - n_active) * mod.weight.shape[0]
                    )
                else:
                    prune.remove(mod, "weight")

    def forward(self, x):
        for mod in self.conv_features:
            x = mod(x)
        x = torch.flatten(x, 1)
        for mod in self.linear_features:
            x = mod(x)
        return x

def cnv(cfg):
    weight_bit_width = cfg.getint("QUANT", "WEIGHT_BIT_WIDTH")
    act_bit_width = cfg.getint("QUANT", "ACT_BIT_WIDTH")
    in_bit_width = cfg.getint("QUANT", "IN_BIT_WIDTH")
    num_classes = cfg.getint("MODEL", "NUM_CLASSES")
    in_channels = cfg.getint("MODEL", "IN_CHANNELS")
    net = CNV(
        weight_bit_width=weight_bit_width,
        act_bit_width=act_bit_width,
        in_bit_width=in_bit_width,
        num_classes=num_classes,
        in_ch=in_channels,
    )
    return net
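
For context, here is roughly how I call it, as mentioned above: I enable pruning before training and remove it again before compiling for FHE. This is only a sketch, not my exact setup: `compile_brevitas_qat_model` comes from Concrete ML (`concrete.ml.torch.compile`), and the bit widths and calibration data below are placeholders.

import torch
from concrete.ml.torch.compile import compile_brevitas_qat_model

# Sketch of the intended workflow; hyper-parameters are illustrative.
model = CNV(num_classes=10, weight_bit_width=2, act_bit_width=2, in_bit_width=8, in_ch=3)
model.toggle_pruning(True)  # enable pruning before training
# ... train on CIFAR-10 with the pruning re-parametrization in place ...
model.toggle_pruning(False)  # prune.remove() makes the pruned weights permanent
model.eval()

calib_data = torch.randn(100, 3, 32, 32)  # representative inputs for calibration
quantized_module = compile_brevitas_qat_model(model, calib_data)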