import random

import numpy as np
import tensorflow as tf
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from tensorflow.keras import backend as K

from datasets.adult.Adult import Adult


def create_clients(X, y, num_clients):
    # create a list of client names
    client_names = ['clients_{}'.format(i + 1) for i in range(num_clients)]

    # randomize the data
    data = list(zip(X, y))
    random.shuffle(data)

    # shard the data and place one shard at each client
    size = len(data) // num_clients
    shards = [data[i:i + size] for i in range(0, size * num_clients, size)]

    # number of clients must equal number of shards
    assert len(shards) == len(client_names)

    return {client_names[i]: shards[i] for i in range(len(client_names))}


def batch_data(data_shard, bs=10):
    data, label = zip(*data_shard)
    dataset = tf.data.Dataset.from_tensor_slices((list(data), list(label)))
    return dataset.shuffle(len(label)).batch(bs)


class SimpleMLP:
    @staticmethod
    def build(n_features):
        return tf.keras.models.Sequential([
            tf.keras.layers.InputLayer(input_shape=(n_features,)),
            tf.keras.layers.Dense(10, activation='tanh'),
            tf.keras.layers.Dense(1, activation='sigmoid'),
        ])


def weight_scalling_factor(clients_trn_data, client_name):
    client_names = list(clients_trn_data.keys())
    # infer the batch size from the first batch of this client's dataset
    bs = list(clients_trn_data[client_name])[0][0].shape[0]
    # first calculate the total number of training data points across clients
    global_count = sum(
        [tf.data.experimental.cardinality(clients_trn_data[name]).numpy()
         for name in client_names]) * bs
    # get the total number of data points held by this client
    local_count = tf.data.experimental.cardinality(clients_trn_data[client_name]).numpy() * bs
    return local_count / global_count


def scale_model_weights(weight, scalar):
    weight_final = []
    steps = len(weight)
    for i in range(steps):
        weight_final.append(scalar * weight[i])
    return weight_final


def sum_scaled_weights(scaled_weight_list):
    '''Return the element-wise sum of the scaled weights.
    Because each entry was already scaled by its client's data fraction,
    this is equivalent to the weighted average of the client weights.'''
    avg_grad = list()
    # sum each layer's weights across all clients
    for grad_list_tuple in zip(*scaled_weight_list):
        layer_mean = tf.math.reduce_sum(grad_list_tuple, axis=0)
        avg_grad.append(layer_mean)
    return avg_grad


def test_model(X_test, Y_test, model, comm_round):
    logits = model.predict(X_test)
    y_pred = np.greater_equal(logits, 0.5)
    y_true = np.greater_equal(Y_test, 0.5)
    acc = accuracy_score(y_true, y_pred)
    print('comm_round: {} | global_acc: {:.3%}'.format(comm_round, acc))
    return acc


if __name__ == '__main__':
    df_X, df_y = Adult().get_data("", "_")
    X = df_X.to_numpy()
    y = df_y.to_numpy()

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
    X_train = X_train.astype(np.float32)
    y_train = y_train.astype(np.int32)
    X_test = X_test.astype(np.float32)
    y_test = y_test.astype(np.int32).reshape(len(y_test), 1)

    clients = create_clients(X_train, y_train, 1)

    # process and batch the training data for each client
    clients_batched = dict()
    for (client_name, data) in clients.items():
        clients_batched[client_name] = batch_data(data)

    # process and batch the test set as a single batch
    test_batched = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(len(y_test))

    lr = 0.01
    comms_round = 100

    smlp_global = SimpleMLP()
    global_model = smlp_global.build(len(df_X.columns))

    # commence global training loop
    for comm_round in range(comms_round):

        # get the global model's weights - they serve as the initial weights for all local models
        global_weights = global_model.get_weights()

        # list to collect the local model weights after scaling
        scaled_local_weight_list = []

        # randomize the client order - using keys
        client_names = list(clients_batched.keys())
        random.shuffle(client_names)

        # loop through each client and create a new local model
        for client in client_names:
            smlp_local = SimpleMLP()
            local_model = smlp_local.build(len(df_X.columns))
            local_model.compile(loss=tf.keras.losses.BinaryCrossentropy(),
                                optimizer=tf.keras.optimizers.legacy.SGD(learning_rate=lr),
                                metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0.5)])

            # set the local model weights to the weights of the global model
            local_model.set_weights(global_weights)

            # fit the local model with this client's data
            local_model.fit(clients_batched[client], epochs=10, verbose=0)

            # scale the model weights by the client's data fraction and add them to the list
            scaling_factor = weight_scalling_factor(clients_batched, client)
            scaled_weights = scale_model_weights(local_model.get_weights(), scaling_factor)
            scaled_local_weight_list.append(scaled_weights)

            # clear the Keras session to free memory after each client update
            K.clear_session()

        # to get the average over all local models, we simply take the sum of the scaled weights
        average_weights = sum_scaled_weights(scaled_local_weight_list)

        # update the global model
        global_model.set_weights(average_weights)

        # test the global model and print out metrics after each communication round
        for X_test, Y_test in test_batched:
            global_acc = test_model(X_test, Y_test, global_model, comm_round)
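

# -----------------------------------------------------------------------------
# Optional sanity check (illustrative sketch, not part of the training script
# above and never invoked by it): it demonstrates that scale_model_weights
# followed by sum_scaled_weights yields the FedAvg aggregation, i.e. a
# data-size-weighted average of the client weights. The client weights and
# sample counts below are made-up example values; call the function manually
# if you want to run the check.
def _fedavg_aggregation_sanity_check():
    # two fake "clients", each holding a single weight matrix
    client_a = [np.array([[1.0, 2.0], [3.0, 4.0]])]
    client_b = [np.array([[5.0, 6.0], [7.0, 8.0]])]

    # assume client A holds 30 samples and client B holds 70 samples
    n_a, n_b = 30, 70
    total = n_a + n_b

    # scale each client's weights by its share of the data, then sum
    scaled_a = scale_model_weights(client_a, n_a / total)
    scaled_b = scale_model_weights(client_b, n_b / total)
    aggregated = sum_scaled_weights([scaled_a, scaled_b])

    # the aggregate should equal the explicit weighted average
    expected = (n_a * client_a[0] + n_b * client_b[0]) / total
    np.testing.assert_allclose(aggregated[0].numpy(), expected)
    print('FedAvg aggregation sanity check passed')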