import random
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import tensorflow as tf
from tensorflow.keras import backend as K
from datasets.adult.Adult import Adult
def create_clients(X, y, num_clients):
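    """Shuffle (X, y) and split it into num_clients equally sized shards,
    returned as a dict keyed by generated client names."""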
    # create a list of client names
    client_names = ['clients_{}'.format(i + 1) for i in range(num_clients)]
    # randomize the data
    data = list(zip(X, y))
    random.shuffle(data)
    # shard data and place at each client
    size = len(data) // num_clients
    shards = [data[i:i + size] for i in range(0, size * num_clients, size)]
    # number of clients must equal number of shards
    assert (len(shards) == len(client_names))
    return {client_names[i]: shards[i] for i in range(len(client_names))}
def batch_data(data_shard, bs=10):
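    """Unzip a client's (feature, label) shard and wrap it in a shuffled,
    batched tf.data.Dataset."""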
    data, label = zip(*data_shard)
    dataset = tf.data.Dataset.from_tensor_slices((list(data), list(label)))
    return dataset.shuffle(len(label)).batch(bs)
class SimpleMLP:
    @staticmethod
    def build(n_features):
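        """Build a one-hidden-layer MLP for binary classification on n_features inputs."""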
        return tf.keras.models.Sequential([
            tf.keras.layers.InputLayer(input_shape=(n_features,)),
            tf.keras.layers.Dense(10, activation='tanh'),
            tf.keras.layers.Dense(1, activation='sigmoid'),
        ])
def weight_scaling_factor(clients_trn_data, client_name):
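    """Return this client's fraction of the total training data, i.e. the
    FedAvg weighting factor n_k / n. Counts are derived from batch cardinality,
    so they are approximate if the last batch is not full."""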
    client_names = list(clients_trn_data.keys())
    # infer the batch size from the first batch
    bs = list(clients_trn_data[client_name])[0][0].shape[0]
    # first calculate the total number of training data points across clients
    global_count = sum(
        [tf.data.experimental.cardinality(clients_trn_data[name]).numpy() for name in client_names]) * bs
    # get the total number of data points held by this client
    local_count = tf.data.experimental.cardinality(clients_trn_data[client_name]).numpy() * bs
    return local_count / global_count
def scale_model_weights(weight, scalar):
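    """Multiply each layer's weight array by scalar (the client's weighting factor)."""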
    weight_final = []
    steps = len(weight)
    for i in range(steps):
        weight_final.append(scalar * weight[i])
    return weight_final
def sum_scaled_weights(scaled_weight_list):
    '''Return the element-wise sum of the scaled weight lists. Since each
    client's weights were pre-scaled by its share of the data, this sum is
    the weighted (FedAvg) average of the weights.'''
    avg_weights = list()
    # sum each layer's scaled weights across all clients
    for layer_tuple in zip(*scaled_weight_list):
        layer_sum = tf.math.reduce_sum(layer_tuple, axis=0)
        avg_weights.append(layer_sum)
    return avg_weights
def test_model(X_test, Y_test, model, comm_round):
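    """Evaluate the global model on the held-out test set and print its accuracy
    for the given communication round."""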
    # the model outputs sigmoid probabilities; threshold at 0.5 to get class labels
    probs = model.predict(X_test, verbose=0)
    y_pred = np.greater_equal(probs, 0.5)
    y_true = np.greater_equal(Y_test, 0.5)
    acc = accuracy_score(y_true, y_pred)
    print('comm_round: {} | global_acc: {:.3%}'.format(comm_round, acc))
    return acc
if __name__ == '__main__':
    df_X, df_y = Adult().get_data("", "_")
    X = df_X.to_numpy()
    y = df_y.to_numpy()
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
    X_train = X_train.astype(np.float32)
    y_train = y_train.astype(np.int32)
    X_test = X_test.astype(np.float32)
    y_test = y_test.astype(np.int32).reshape(len(y_test), 1)
    clients = create_clients(X_train, y_train, 1)
    # process and batch the training data for each client
    clients_batched = dict()
    for (client_name, data) in clients.items():
        clients_batched[client_name] = batch_data(data)
    # process and batch the test set
    test_batched = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(len(y_test))
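    # the test set is kept as a single batch so the global model is evaluated in one pass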
    lr = 0.01
    comms_round = 100
    smlp_global = SimpleMLP()
    global_model = smlp_global.build(len(df_X.columns))
    # commence global training loop
    for comm_round in range(comms_round):
        # get the global model's weights - will serve as the initial weights for all local models
        global_weights = global_model.get_weights()
        # initialize list to collect local model weights after scaling
        scaled_local_weight_list = []
        # randomize client order - using keys
        client_names = list(clients_batched.keys())
        random.shuffle(client_names)
        # loop through each client and create a new local model
        for client in client_names:
            smlp_local = SimpleMLP()
            local_model = smlp_local.build(len(df_X.columns))
            local_model.compile(loss=tf.keras.losses.BinaryCrossentropy(),
                                optimizer=tf.keras.optimizers.legacy.SGD(learning_rate=lr),
                                metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0.5)])
            # set local model weights to the weights of the global model
            local_model.set_weights(global_weights)
            # fit local model with this client's data
            local_model.fit(clients_batched[client], epochs=10, verbose=0)
            # scale the model weights by the client's data share and add to the list
            scaling_factor = weight_scaling_factor(clients_batched, client)
            scaled_weights = scale_model_weights(local_model.get_weights(), scaling_factor)
            scaled_local_weight_list.append(scaled_weights)
            # clear the Keras backend session to free memory after each client's local training
            K.clear_session()
        # the weighted average over all local models is simply the sum of the pre-scaled weights
        average_weights = sum_scaled_weights(scaled_local_weight_list)
        # update global model
        global_model.set_weights(average_weights)
        # test global model and print out metrics after each communication round
        for X_test, Y_test in test_batched:
            global_acc = test_model(X_test, Y_test, global_model, comm_round)