client.py
import copy
import pickle
import grpc
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.linear_model import SGDClassifier
from concrete import fhe
import federatedlearning_pb2
import federatedlearning_pb2_grpc
class Fed_Client:
    def __init__(self, client_id, channel, test_size=0.2, random_state=42):
self.client_id = client_id
self.stub = federatedlearning_pb2_grpc.FederatedLearningServiceStub(channel)
self.X = None
self.y = None
self.model = None
self.quant_model = None
self.quant_model_w = None
self.quant_model_b = None
self.encrypt_model = None
self.avg_encrypt_model = None
self.qw_max = 1
self.qw_min = -1
self.qb_max = 1
self.qb_min = -1
self.n_bit = 7
self.test_size = test_size
self.random_state = random_state
self.X_samples, self.y_samples = None, None
self.circuit_sum_w, self.circuit_sum_b = None, None
self.scale_w, self.scale_b = None, None
self.Zp_w, self.Zp_b = None, None
self.circuit_sum_w_client = None
self.circuit_sum_b_client = None
        self.w_global = None
        self.b_global = None
self.X_train, self.y_train, self.X_test, self.y_test = None, None, None, None
def data_init(self, X_path, y_path):
"""Load and preprocess data."""
self.X = np.load(X_path, allow_pickle=True)
self.y = np.load(y_path, allow_pickle=True)
le = LabelEncoder()
self.y = le.fit_transform(self.y)
self.X = self.remove_zero_gene(self.X)
self.X = self.read_counts_to_CPM(self.X)
        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
            self.X, self.y, test_size=self.test_size, random_state=self.random_state)
        # Keep one representative sample per class for model initialization.
        self.X_samples, self.y_samples = self.selRepsample(self.X_train, self.y_train)
        # Each client trains on its own 16% shard of the shared training split, indexed by client_id.
        n_train = len(self.X_train)
        start, end = int(0.16 * self.client_id * n_train), int(0.16 * (self.client_id + 1) * n_train)
        self.X_train, self.y_train = self.X_train[start:end], self.y_train[start:end]
@staticmethod
def remove_zero_gene(X):
"""Remove zero-variance genes from the dataset."""
new_X_t = []
X_t = np.transpose(X)
for gene in X_t:
if np.sum(gene)!=0:
new_X_t.append(gene)
new_X_t = np.array(new_X_t)
return np.transpose(new_X_t)
@staticmethod
def read_counts_to_CPM(X):
"""Convert gene counts to Counts Per Million (CPM)."""
new_X = []
for sample in X:
            cpm = sample / np.sum(sample) * 1e6
            cpm_log = np.log2(cpm + 1)  # log2(CPM + 1) transform
            new_X.append(cpm_log)
return np.array(new_X)
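    # Worked example (illustrative numbers): a gene with 200 reads in a sample with
    # 2,000,000 total counts gives CPM = 200 / 2e6 * 1e6 = 100, stored as
    # log2(100 + 1) ≈ 6.66.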
@staticmethod
    def selRepsample(X, y):
        """Select one representative (X, y) sample per class label."""
        X_samples, y_samples = [], []
y_set = set(list(y))
for y_label in y_set:
for x_sample,y_sample in zip(X,y):
if y_sample == y_label:
X_samples.append(x_sample)
y_samples.append(y_sample)
break
return X_samples, y_samples
@staticmethod
    def quantization(weight, n_bits=8):
        """Affine-quantize `weight` to unsigned integers, shifted by +128 into uint8 range."""
        # Accept a plain list (as passed from keySync) or an ndarray.
        weight = np.asarray(weight)
        max_X = np.max(weight)
        min_X = np.min(weight)
        max_q_value = 2 ** n_bits - 1
        value_range = max_X - min_X
        scale = value_range / max_q_value
        # Zero point: the position on the quantized grid that corresponds to a value of 0.
        Zp = np.round((-min_X * max_q_value) / value_range)
        # With n_bits <= 7 the +128 offset keeps the result inside [128, 255], so it
        # still fits in uint8 and stays non-negative for the FHE circuits.
        q_X = np.round(weight / scale) + Zp + 128
        q_X = q_X.astype(np.uint8)
        q_X = q_X.squeeze(0)
        return q_X
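    # Worked example (hypothetical numbers, n_bits=7 as used by this client):
    #   weights in [-0.5, 0.5] -> max_q_value = 127, scale = 1.0 / 127 ≈ 0.00787,
    #   Zp = round(0.5 * 127 / 1.0) = 64.
    #   A weight of 0.2 maps to round(0.2 / scale) + Zp + 128 = 25 + 64 + 128 = 217
    #   and dequantizes back as (217 - 128 - 64) * scale ≈ 0.197.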
@staticmethod
    def pre_quantization(weight):
        """Return the local [max, min] of `weight`; the server aggregates these ranges."""
        max_X = np.max(weight)
        min_X = np.min(weight)
        return [max_X, min_X]
@staticmethod
    def agg_weight_range(client_ranges):
        """Aggregate per-client [max, min] pairs into a single global max and min."""
        client_ranges = np.array(client_ranges)
        agg_max_X = np.max(client_ranges[:, 0])
        agg_min_X = np.min(client_ranges[:, 1])
        return agg_max_X, agg_min_X
@staticmethod
    def post_quantization(weight, agg_max_X=None, agg_min_X=None, n_bits=8):
        """Quantize `weight` using the (optionally aggregated) range; returns (q_X, Zp, scale)."""
        if agg_max_X is None or agg_min_X is None:
            max_X = np.max(weight)
            min_X = np.min(weight)
        else:
            max_X = agg_max_X
            min_X = agg_min_X
        max_q_value = 2 ** n_bits - 1
        value_range = max_X - min_X
        scale = value_range / max_q_value
        Zp = np.round((-min_X * max_q_value) / value_range)
        # Same +128 offset as in `quantization` above.
        q_X = np.round(weight / scale) + Zp + 128
        q_X = q_X.astype(np.uint8)
        return q_X, Zp, scale
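    # Minimal dequantization sketch (hypothetical helper mirroring what
    # get_average_encrypted_model does with the decrypted FHE sum):
    #
    #   def dequantize_sum(q_sum, n_clients, Zp, scale):
    #       q_avg = (q_sum / n_clients) - 128   # undo averaging and the +128 offset
    #       return (q_avg - Zp) * scale         # invert the affine quantization
    #
    # e.g. dequantize_sum(decrypted_sum, 5, self.Zp_w, self.scale_w) recovers an
    # approximation of the plain federated average of the weights.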
@fhe.compiler({"scaled_weight_list1":"encrypted","scaled_weight_list2":"encrypted"})
def encrypted_sum_w(scaled_weight_list1, scaled_weight_list2):
w_avg = scaled_weight_list1 + scaled_weight_list2
return w_avg
@fhe.compiler({"scaled_weight_list1":"encrypted","scaled_weight_list2":"encrypted"})
def encrypted_sum_b(scaled_weight_list1, scaled_weight_list2):
w_avg = scaled_weight_list1 + scaled_weight_list2
return w_avg
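    # How a circuit like this is used with concrete-python (sketch only; the real flow
    # in this project is split across keySync, send_encrypted_model and the server):
    #
    #   inputset = [(sample1, sample2), ...]                      # representative uint8 arrays
    #   circuit = Fed_Client.encrypted_sum_w.compile(inputset)    # trace + compile to FHE
    #   circuit.keygen()                                          # generate secret/eval keys
    #   enc_a, enc_b = circuit.encrypt(update_a, update_b)        # encrypt two updates
    #   enc_sum = circuit.run(enc_a, enc_b)                       # homomorphic addition
    #   q_sum = circuit.decrypt(enc_sum)                          # plaintext integer sum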
    def model_init(self):
        """Initialize the local model and register it with the server."""
        self.model = SGDClassifier(random_state=0, learning_rate='constant', eta0=0.000001,
                                   alpha=0.5, class_weight='balanced', loss='hinge')
        # Fit on one sample per class so that coef_/intercept_ get the right shape.
        self.model.fit(self.X_samples, self.y_samples)
        response = self.stub.InitializeModel(federatedlearning_pb2.ModelInitRequest(
            client_id=self.client_id, model_data=pickle.dumps(self.model)))
        print("Model initialized:", response.client_id)
        # Re-create the model with a small max_iter for the per-round local updates,
        # then copy in the parameters returned by the server.
        self.model = SGDClassifier(random_state=0, learning_rate='constant', eta0=0.000001,
                                   alpha=0.5, class_weight='balanced', loss='hinge', max_iter=5)
        self.model.fit(self.X_samples, self.y_samples)
        model_data = pickle.loads(response.model_data)
        self.model.coef_ = copy.deepcopy(model_data.coef_)
        self.model.intercept_ = copy.deepcopy(model_data.intercept_)
    def train(self):
        """Run one round of local training, warm-started from the current global parameters."""
        self.model.fit(self.X_train, self.y_train,
                       coef_init=copy.deepcopy(self.w_global),
                       intercept_init=copy.deepcopy(self.b_global))
        print("Training complete")
def send_quantization_parameters(self):
"""Send quantization parameters to the server."""
self.qw_max, self.qw_min = self.pre_quantization(self.model.coef_)
self.qb_max, self.qb_min = self.pre_quantization(self.model.intercept_)
params = self.stub.SendQuantizationParameters(federatedlearning_pb2.QuantizationParameters(
client_id=self.client_id, Q_max=self.qw_max, Q_min=self.qw_min, n_bit=self.n_bit
))
bias = self.stub.SendQuantizationbias(federatedlearning_pb2.QuantizationParameters(
client_id=self.client_id, Q_max=self.qb_max, Q_min=self.qb_min, n_bit=self.n_bit
))
print("Quantization parameters sent:", params, bias)
def get_aggregated_quantization_parameters(self):
"""Retrieve aggregated quantization parameters from the server."""
params = self.stub.GetAggregatedQuantizationParameters(federatedlearning_pb2.ClientId(client_id=self.client_id))
print("Received aggregated quantization parameters:", params)
self.qw_max = params.Qw_max
self.qw_min = params.Qw_min
self.qb_max = params.Qb_max
self.qb_min = params.Qb_min
    def keySync(self):
        """Compile the FHE summation circuits and exchange circuits/keys with the server."""
        # Quantize the current parameters once, only to build a representative inputset
        # for circuit compilation.
        scaled_local_weight_list = [self.model.coef_]
        scaled_local_b_list = [self.model.intercept_]
        q_w = self.quantization(scaled_local_weight_list, n_bits=self.n_bit)
        q_b = self.quantization(scaled_local_b_list, n_bits=self.n_bit)
        # The compilation inputset has to cover the value ranges produced when summing
        # the updates of up to 5 clients.
        circuit_sum_w = self.encrypted_sum_w.compile(
            [(q_w, q_w * 1), (q_w, q_w * 2), (q_w, q_w * 3), (q_w, q_w * 4), (q_w, q_w * 5), (q_w * 2, q_w * 2)])
        circuit_sum_b = self.encrypted_sum_b.compile(
            [(q_b, q_b * 1), (q_b, q_b * 2), (q_b, q_b * 3), (q_b, q_b * 4), (q_b, q_b * 5), (q_b * 2, q_b * 2)])
circuit_sum_w.server.save("./server_w.zip")
circuit_sum_b.server.save("./server_b.zip")
upload = []
with open("./server_w.zip", "rb") as f:
zip_data_w = f.read()
upload.append(zip_data_w)
with open("./server_b.zip", "rb") as f:
zip_data_b = f.read()
upload.append(zip_data_b)
        circuit_zip = pickle.dumps(upload)
        response = self.stub.SyncKeys(federatedlearning_pb2.RequestKeys(
            client_id=self.client_id, keys=circuit_zip))
        self.key = pickle.loads(response.keys)
client_specs_w = fhe.ClientSpecs.deserialize(self.key[0])
client_specs_b = fhe.ClientSpecs.deserialize(self.key[1])
self.circuit_sum_w_client = fhe.Client(client_specs_w)
self.circuit_sum_b_client = fhe.Client(client_specs_b)
        # All clients must hold the same keys so the server can sum their ciphertexts;
        # here this is forced with a shared key directory and a fixed seed (demo only).
        self.circuit_sum_w_client.keys.load_if_exists_generate_and_save_otherwise("./key", seed=420)
        self.circuit_sum_b_client.keys.load_if_exists_generate_and_save_otherwise("./key", seed=420)
serialized_evaluation_keys_w: bytes = self.circuit_sum_w_client.evaluation_keys.serialize()
serialized_evaluation_keys_b: bytes = self.circuit_sum_b_client.evaluation_keys.serialize()
Evkeys = pickle.dumps([serialized_evaluation_keys_w, serialized_evaluation_keys_b])
response = self.stub.SyncEvKeys(federatedlearning_pb2.RequestEvKeys(client_id=self.client_id, keys=Evkeys ))
print("Received keys and set")
    def send_encrypted_model(self):
        """Quantize, encrypt, and upload the local model parameters."""
        # Quantize weights with the aggregated weight range and the bias with the
        # aggregated bias range.
        self.quant_model_w, self.Zp_w, self.scale_w = self.post_quantization(
            self.model.coef_, self.qw_max, self.qw_min, n_bits=self.n_bit)
        self.quant_model_b, self.Zp_b, self.scale_b = self.post_quantization(
            self.model.intercept_, self.qb_max, self.qb_min, n_bits=self.n_bit)
        encrypted_weight = self.circuit_sum_w_client.encrypt(None, self.quant_model_w)
        encrypted_bias = self.circuit_sum_b_client.encrypt(None, self.quant_model_b)
        # encrypt() returns one value per circuit argument; only the second slot is used here.
        serialized_en_w = encrypted_weight[1].serialize()
        serialized_en_b = encrypted_bias[1].serialize()
        self.encrypt_model = pickle.dumps([serialized_en_w, serialized_en_b])
        response = self.stub.SendEncryptedModel(federatedlearning_pb2.EncryptedModel(
            client_id=self.client_id, encrypted_data=self.encrypt_model))
        print("Encrypted model sent:", response)
    def get_average_encrypted_model(self):
        """Retrieve the encrypted sum from the server, decrypt, average, and dequantize it."""
        en_model = self.stub.GetAverageEncryptedModel(federatedlearning_pb2.ClientId(client_id=self.client_id))
        en_model = pickle.loads(en_model.encrypted_data)
        self.encrypt_avg_w = fhe.Value.deserialize(en_model[0])
        self.encrypt_avg_b = fhe.Value.deserialize(en_model[1])
        # Decrypt the encrypted sum, average over the 5 clients, remove the +128 offset,
        # and undo the affine quantization.
        quant_avg_w = self.circuit_sum_w_client.decrypt(self.encrypt_avg_w)
        quant_avg_w = (quant_avg_w / 5) - 128
        w_global = (quant_avg_w - self.Zp_w) * self.scale_w
        self.model.coef_ = copy.deepcopy(w_global)
        self.w_global = copy.deepcopy(w_global)
        if self.model.fit_intercept:
            quant_avg_b = self.circuit_sum_b_client.decrypt(self.encrypt_avg_b)
            quant_avg_b = (quant_avg_b / 5) - 128
            b_global = (quant_avg_b - self.Zp_b) * self.scale_b
            self.model.intercept_ = copy.deepcopy(b_global)
            self.b_global = copy.deepcopy(b_global)
        print("Received average encrypted model")
        self.avg_encrypt_model = en_model
    def test(self):
        """Evaluate the current model on the held-out test split."""
        y_pred = self.model.predict(self.X_test)
        acc = accuracy_score(self.y_test, y_pred)
        f1 = f1_score(self.y_test, y_pred, average='weighted')
        print('**b_global**' * 5)
        print(self.b_global)
        print('**final test**' * 5)
        print(acc, f1)
        print('**final test**' * 5)
    def testQuant(self, w=None, b=None):
        """Simulate the 5-client aggregation in plaintext and evaluate the dequantized model."""
        if w is None:
            w = self.quant_model_w
            b = self.quant_model_b
            testname = 'Quant Test'
        else:
            testname = 'Encrypt Test'
        # Sum in a wide integer type so the uint8 inputs do not wrap around.
        quant_avg_w = np.asarray(w, dtype=np.int64) * 5
        quant_avg_b = np.asarray(b, dtype=np.int64) * 5
        quant_avg_w = (quant_avg_w / 5) - 128
        w_global = (quant_avg_w - self.Zp_w) * self.scale_w
        quant_avg_b = (quant_avg_b / 5) - 128
        b_global = (quant_avg_b - self.Zp_b) * self.scale_b
        self.model.coef_ = copy.deepcopy(w_global)
        self.model.intercept_ = copy.deepcopy(b_global)
        y_pred = self.model.predict(self.X_test)
        acc = accuracy_score(self.y_test, y_pred)
        f1 = f1_score(self.y_test, y_pred, average='weighted')
        print('**{}**'.format(testname) * 5)
        print(acc, f1)
        print('**{}**'.format(testname) * 5)
    def testEncrypt(self):
        """Round-trip the quantized parameters through encrypt/decrypt, then re-run testQuant."""
        encrypted_weight = self.circuit_sum_w_client.encrypt(None, self.quant_model_w)
        encrypted_bias = self.circuit_sum_b_client.encrypt(None, self.quant_model_b)
        de_weight = self.circuit_sum_w_client.decrypt(encrypted_weight[1])
        de_bias = self.circuit_sum_b_client.decrypt(encrypted_bias[1])
        print("--Enc bias--" * 5)
        print(self.quant_model_b)
        print(de_bias)
        self.testQuant(w=de_weight, b=de_bias)
if __name__ == '__main__':
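    # Usage sketch (assuming the server's max_clients of 5): start server.py first, then
    # launch five copies of this script, each constructing Fed_Client with a distinct
    # client_id in 0..4. Each round, the aggregation RPCs block until all five clients
    # have reached the same step.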
channel = grpc.insecure_channel('localhost:50055',
options=[
('grpc.max_send_message_length', -1),
('grpc.max_receive_message_length', -1),
('grpc.keepalive_timeout_ms', 600000)
])
client = Fed_Client(client_id=4, channel=channel)
client.data_init('/home/shenbc/sc/naivefl/federatedSinglecell/X.npy', '/home/shenbc/sc/naivefl/federatedSinglecell/y.npy')
client.model_init()
client.keySync()
for i in range(50):
print('--Training Round{}--'.format(i)*5)
client.train()
client.send_quantization_parameters()
client.get_aggregated_quantization_parameters()
client.send_encrypted_model()
client.testQuant()
client.testEncrypt()
client.get_average_encrypted_model()
client.test()
server.py
from concurrent import futures
import pickle
import threading
import grpc
import numpy as np
import federatedlearning_pb2
import federatedlearning_pb2_grpc
from concrete import fhe
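# Cross-thread synchronization state: each aggregation step below is a simple barrier
# built from this Condition plus a counter, released once all clients have arrived.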
condition = threading.Condition()
request_counter = 0
all_requests_processed = False
Quant_request_counter = 0
Quant_all_requests_processed = False
class FederatedLearningServicer(federatedlearning_pb2_grpc.FederatedLearningServiceServicer):
def __init__(self):
# This could store model parameters, quantization details, etc.
self.models = {} # Dictionary to store model states
self.encrypted_weight = {}
self.encrypted_bias = {}
self.quant_params = {}
self.N_BITS = 8
self.circuit_sum_w = None
self.circuit_sum_b = None
self.key_w = None
self.key_b = None
self.Evkey_w = None
self.Evkey_b = None
self.encrypt_avg = None
self.w_maxmin = []
self.b_maxmin = []
self.max_clients = 5
self.qu_coef_max, self.qu_coef_min =None,None
self.qu_intercept_max, self.qu_intercept_min =None,None
    def InitializeModel(self, request, context):
        # Store the model submitted by the client and echo it back.
        print(f"Initializing model for client {request.client_id}")
        self.models[request.client_id] = pickle.loads(request.model_data)
        return federatedlearning_pb2.ModelInitRequest(client_id=request.client_id, model_data=pickle.dumps(self.models[request.client_id]))
@staticmethod
def agg_weight_range(client_ranges):
agg_max_X = np.max(np.array(client_ranges)[:, 0])
agg_min_X = np.min(np.array(client_ranges)[:, 1])
return agg_max_X, agg_min_X
    def SyncKeys(self, request, context):
        # Receive the compiled circuits from a client, load them as FHE servers, and
        # return the serialized client specs so clients can build matching fhe.Client objects.
        request_w, request_b = pickle.loads(request.keys)
        with open("./server_save_w.zip", "wb+") as f:
            f.write(request_w)
        with open("./server_save_b.zip", "wb+") as f:
            f.write(request_b)
        self.circuit_sum_w = fhe.Server.load("./server_save_w.zip")
        self.circuit_sum_b = fhe.Server.load("./server_save_b.zip")
        print('SyncKeys: circuits received and loaded')
        serialized_client_specs_w = self.circuit_sum_w.client_specs.serialize()
        serialized_client_specs_b = self.circuit_sum_b.client_specs.serialize()
        client_specs = pickle.dumps([serialized_client_specs_w, serialized_client_specs_b])
        return federatedlearning_pb2.SendKeys(client_id=request.client_id, keys=client_specs)
def SyncEvKeys(self, request, context):
Evkeys = pickle.loads(request.keys)
self.Evkey_w = fhe.EvaluationKeys.deserialize(Evkeys[0])
self.Evkey_b = fhe.EvaluationKeys.deserialize(Evkeys[1])
print('SyncEvKeys')
return federatedlearning_pb2.SendEvKeys(client_id=request.client_id)
def TrainModel(self, request, context):
# Simulate training the model with provided data
print(f"Training model for client {request.client_id}")
# Here, just echoing back the received data for simplicity
return federatedlearning_pb2.TrainRequest(client_id=request.client_id, model_data=request.model_data)
def SendQuantizationParameters(self, request, context):
# Receive and store quantization parameters
print(f"Received quantization parameters from client {request.client_id}: Q_max={request.Q_max}, Q_min={request.Q_min}, n_bits={request.n_bit}")
self.w_maxmin.append([request.Q_max, request.Q_min])
# Return an acknowledgement immediately
return federatedlearning_pb2.Acknowledgement(message="Parameters received, aggregation pending.")
    def SendQuantizationbias(self, request, context):
        # Receive and store bias quantization parameters
        print(f"Received bias quantization parameters from client {request.client_id}: Q_max={request.Q_max}, Q_min={request.Q_min}, n_bits={request.n_bit}")
self.b_maxmin.append([request.Q_max, request.Q_min])
# Return an acknowledgement immediately
return federatedlearning_pb2.Acknowledgement(message="Parameters received, aggregation pending.")
    def GetAggregatedQuantizationParameters(self, request, context):
        # Barrier: wait until every client has submitted its ranges, aggregate once,
        # then release all waiting clients with the same parameters.
        print(f"Sending aggregated quantization parameters to client {request.client_id}")
        global Quant_request_counter
        global Quant_all_requests_processed
        with condition:
            if Quant_request_counter == 0:
                # First arrival of a new round: re-arm the barrier.
                Quant_all_requests_processed = False
            Quant_request_counter += 1
            if Quant_request_counter == self.max_clients:
                self.qu_coef_max, self.qu_coef_min = self.agg_weight_range(self.w_maxmin)
                self.qu_intercept_max, self.qu_intercept_min = self.agg_weight_range(self.b_maxmin)
                print(self.w_maxmin)
                print(self.b_maxmin)
                # Reset per-round state so the next round starts from a clean slate.
                self.w_maxmin, self.b_maxmin = [], []
                Quant_request_counter = 0
                Quant_all_requests_processed = True
                condition.notify_all()
            else:
                while not Quant_all_requests_processed:
                    condition.wait()
        return federatedlearning_pb2.ServerQuantizationParameters(
            Qw_max=self.qu_coef_max, Qw_min=self.qu_coef_min,
            Qb_max=self.qu_intercept_max, Qb_min=self.qu_intercept_min)
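    # The two aggregation RPCs (GetAggregatedQuantizationParameters and
    # GetAverageEncryptedModel) share the same barrier pattern, sketched here with
    # `counter`/`done` standing in for the module-level globals:
    #
    #   with condition:
    #       counter += 1
    #       if counter == max_clients:
    #           result = aggregate()        # only the last arrival does the work
    #           counter, done = 0, True     # re-arm for the next round
    #           condition.notify_all()
    #       else:
    #           while not done:
    #               condition.wait()        # releases the lock while sleeping
    #   return result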
    def SendEncryptedModel(self, request, context):
        # Store the client's encrypted (weight, bias) pair for this round.
        print(f"Received encrypted model from client {request.client_id}")
        en_wb = pickle.loads(request.encrypted_data)
        self.models[request.client_id] = en_wb
        self.encrypted_weight[request.client_id] = (None, fhe.Value.deserialize(en_wb[0]))
        self.encrypted_bias[request.client_id] = (None, fhe.Value.deserialize(en_wb[1]))
        global all_requests_processed
        # Re-arm the averaging barrier: a new round of uploads has started.
        all_requests_processed = False
        return federatedlearning_pb2.Acknowledgement(message="Encrypted model received.")
    def GetAverageEncryptedModel(self, request, context):
        # Barrier: wait until all clients have uploaded, then homomorphically sum the
        # encrypted updates and return the encrypted sum to every client.
        print(f"Sending average encrypted model to client {request.client_id}")
        global request_counter
        global all_requests_processed
        with condition:
            request_counter += 1
            if request_counter == self.max_clients:
                weight_keys = list(self.encrypted_weight.keys())
                bias_keys = list(self.encrypted_bias.keys())
                encrypt_avg_w = self.encrypted_weight[weight_keys[0]][1]
                encrypt_avg_b = self.encrypted_bias[bias_keys[0]][1]
                for i in range(1, len(weight_keys)):
                    encrypt_avg_w = self.circuit_sum_w.run(
                        (self.encrypted_weight[weight_keys[i]][1], encrypt_avg_w),
                        evaluation_keys=self.Evkey_w)
                    encrypt_avg_b = self.circuit_sum_b.run(
                        (self.encrypted_bias[bias_keys[i]][1], encrypt_avg_b),
                        evaluation_keys=self.Evkey_b)
                self.encrypt_avg = pickle.dumps([encrypt_avg_w.serialize(), encrypt_avg_b.serialize()])
                # Re-arm the barrier counter for the next round.
                request_counter = 0
                all_requests_processed = True
                condition.notify_all()
            else:
                while not all_requests_processed:
                    condition.wait()
        return federatedlearning_pb2.EncryptedModel(client_id=request.client_id, encrypted_data=self.encrypt_avg)
def GetServerState(self, request, context):
state = federatedlearning_pb2.ServerState(
encrypted_weight=pickle.dumps(self.encrypted_weight),
encrypted_bias=pickle.dumps(self.encrypted_bias),
quant_params=pickle.dumps(self.quant_params),
N_BITS=self.N_BITS,
circuit_sum_w=pickle.dumps(self.circuit_sum_w if self.circuit_sum_w else b''),
circuit_sum_b=pickle.dumps(self.circuit_sum_b if self.circuit_sum_b else b''),
key_w=pickle.dumps(self.key_w if self.key_w else b''),
key_b=pickle.dumps(self.key_b if self.key_b else b''),
Evkey_w=pickle.dumps(self.Evkey_w if self.Evkey_w else b''),
Evkey_b=pickle.dumps(self.Evkey_b if self.Evkey_b else b''),
w_maxmin=pickle.dumps(self.w_maxmin),
b_maxmin=pickle.dumps(self.b_maxmin)
)
return state
    def DecryptModel(self, request, context):
        # Placeholder endpoint: decryption actually happens on the clients, so this
        # just returns dummy bytes.
        print(f"Decrypting model for client {request.client_id}")
return federatedlearning_pb2.EncryptedModel(client_id=request.client_id, encrypted_data=b'decrypted_model_data')
    # Reference circuit definitions; the circuits actually used at runtime are the
    # ones compiled by the clients and loaded in SyncKeys.
    @staticmethod
    @fhe.compiler({"scaled_weight_list1": "encrypted", "scaled_weight_list2": "encrypted"})
    def encrypted_sum_w(scaled_weight_list1, scaled_weight_list2):
        return scaled_weight_list1 + scaled_weight_list2
    @staticmethod
    @fhe.compiler({"scaled_weight_list1": "encrypted", "scaled_weight_list2": "encrypted"})
    def encrypted_sum_b(scaled_weight_list1, scaled_weight_list2):
        return scaled_weight_list1 + scaled_weight_list2
def serve():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10),
options=[
('grpc.max_send_message_length', -1), # No limit for sending
('grpc.max_receive_message_length', -1) # No limit for receiving
])
federatedlearning_pb2_grpc.add_FederatedLearningServiceServicer_to_server(
FederatedLearningServicer(), server)
server.add_insecure_port('localhost:50055')
server.start()
print("Server started at localhost:50055")
server.wait_for_termination()
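# Note: the thread pool must be large enough to hold one blocked handler per client
# (max_workers >= max_clients), because the aggregation RPCs wait inside their handlers
# until every client has checked in.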
if __name__ == '__main__':
serve()