# Source: pastecode.io paste by mail@pastecode.io
# Posted: a year ago
# Size: 5.7 kB
import random
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import tensorflow as tf
from tensorflow.keras import backend as K

from datasets.adult.Adult import Adult

def create_clients(X, y, num_clients, shuffle=True):
    """Split (X, y) into ``num_clients`` shards, one per simulated client.

    Args:
        X: Iterable of feature rows, aligned element-wise with ``y``.
        y: Iterable of labels.
        num_clients: Number of clients to create. Must be between 1 and the
            number of (x, y) pairs.
        shuffle: If True (the default), shuffle the (x, y) pairs with the
            ``random`` module before sharding, so shards are not contiguous
            slices of the input order.

    Returns:
        dict mapping ``'clients_1'`` .. ``'clients_<num_clients>'`` to a list
        of ``(x, y)`` tuples. Shard sizes differ by at most one, and every
        pair is assigned to exactly one client (no remainder is dropped).

    Raises:
        ValueError: if ``num_clients`` is outside ``[1, number of pairs]``.
    """
    # create a list of client names
    client_names = ['clients_{}'.format(i + 1) for i in range(num_clients)]

    data = list(zip(X, y))
    if not 1 <= num_clients <= len(data):
        raise ValueError(
            'num_clients must be between 1 and {} (got {})'.format(len(data), num_clients))

    # randomize the data so each client receives a random subset
    if shuffle:
        random.shuffle(data)

    # The first `remainder` clients get one extra pair so no data is discarded.
    size, remainder = divmod(len(data), num_clients)
    shards = []
    start = 0
    for i in range(num_clients):
        end = start + size + (1 if i < remainder else 0)
        shards.append(data[start:end])
        start = end

    return dict(zip(client_names, shards))

def batch_data(data_shard, bs=10):
    """Build a shuffled, batched ``tf.data.Dataset`` from one client's shard.

    Args:
        data_shard: Sequence of ``(features, label)`` pairs.
        bs: Number of samples per batch.

    Returns:
        A dataset yielding ``(features, labels)`` batches. The shuffle buffer
        is as large as the shard, so the whole shard is shuffled.
    """
    # Transpose the list of pairs into one list of features and one of labels.
    features, labels = [list(column) for column in zip(*data_shard)]
    shuffle_buffer = len(labels)
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    return dataset.shuffle(shuffle_buffer).batch(bs)

class SimpleMLP:
    """Factory for a small MLP with a single sigmoid output unit."""

    @staticmethod
    def build(n_features):
        """Return an uncompiled model: input -> Dense(10, tanh) -> Dense(1, sigmoid).

        The explicit ``Input`` layer builds the model immediately, so
        ``get_weights()`` / ``set_weights()`` work before the first
        ``fit``/``predict``. ``@staticmethod`` lets this be called either as
        ``SimpleMLP.build(n)`` or on an instance.

        Args:
            n_features: Number of input features per sample.

        Returns:
            A ``tf.keras.models.Sequential`` model. It is not compiled; the
            caller must compile it before calling ``fit``.
        """
        return tf.keras.models.Sequential([
            tf.keras.Input(shape=(n_features,)),
            tf.keras.layers.Dense(10, activation='tanh'),
            tf.keras.layers.Dense(1, activation='sigmoid'),
        ])

def weight_scalling_factor(clients_trn_data, client_name):
    """Return the FedAvg weight of ``client_name``: its share of all samples.

    Args:
        clients_trn_data: dict mapping client name to a re-iterable collection
            of ``(features, labels)`` batches, such as the dataset returned by
            ``batch_data``. Every client's batches are iterated once per call.
        client_name: Key of the client whose weight is requested.

    Returns:
        float equal to local sample count / total sample count across clients.

    Raises:
        KeyError: if ``client_name`` is not a key of ``clients_trn_data``.
        ValueError: if the clients hold no samples at all.
    """
    def _count_samples(batches):
        # Sum each batch's real leading dimension, so a partial final batch is
        # counted exactly (unlike batch_count * first_batch_size).
        return sum(len(features) for features, _ in batches)

    local_count = _count_samples(clients_trn_data[client_name])
    global_count = sum(_count_samples(batches) for batches in clients_trn_data.values())
    if global_count == 0:
        raise ValueError('clients hold no training samples')

    return local_count / global_count

def scale_model_weights(weight, scalar):
    """Return a new list in which every entry of ``weight`` is multiplied by ``scalar``.

    Args:
        weight: List of per-layer weight arrays (for example, from ``get_weights()``).
        scalar: Multiplier applied to every entry.

    Returns:
        A new list of scaled entries, in the same order. The input is not modified.
    """
    return [scalar * layer for layer in weight]

def sum_scaled_weights(scaled_weight_list):
    '''Return the layer-wise sum of several clients' scaled weight lists.

    If each client's weights were pre-multiplied by its share of the data
    (factors summing to 1), this sum is the weighted average of the weights.

    Args:
        scaled_weight_list: One entry per client. Each entry is a list of
            per-layer arrays, and all entries are expected to have the same
            number of layers and the same shapes.

    Returns:
        list of numpy arrays, one per layer. Returns an empty list if
        ``scaled_weight_list`` is empty.
    '''
    # zip(*...) groups the i-th layer of every client together; summing along
    # the stacking axis adds the clients' contributions for that layer.
    return [np.sum(layer_group, axis=0) for layer_group in zip(*scaled_weight_list)]

def test_model(X_test, Y_test, model, comm_round):
    """Compute and print the accuracy of ``model`` on one test batch.

    Model outputs and labels are both thresholded at 0.5 before comparison.

    Args:
        X_test: Features passed to ``model.predict``.
        Y_test: Labels. Values >= 0.5 count as the positive class.
        model: Object with a ``predict`` method, such as a Keras model.
        comm_round: Round index; it is only used in the printed message.

    Returns:
        The accuracy returned by ``sklearn.metrics.accuracy_score``.
    """
    scores = model.predict(X_test)
    y_pred = np.greater_equal(scores, 0.5)
    y_true = np.greater_equal(Y_test, 0.5)
    # sklearn's documented signature is accuracy_score(y_true, y_pred).
    acc = accuracy_score(y_true, y_pred)

    print('comm_round: {} | global_acc: {:.3%}'.format(comm_round, acc))

    return acc

if __name__ == '__main__':
    df_X, df_y = Adult().get_data("", "_")
    X = df_X.to_numpy()
    y = df_y.to_numpy()
    n_features = len(df_X.columns)

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
    X_train = X_train.astype(np.float32)
    y_train = y_train.astype(np.int32)
    X_test = X_test.astype(np.float32)
    y_test = y_test.astype(np.int32).reshape(len(y_test), 1)

    clients = create_clients(X_train, y_train, 1)

    # process and batch the training data for each client
    clients_batched = {name: batch_data(shard) for name, shard in clients.items()}

    # process and batch the test set (one batch holding the whole test split)
    test_batched = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(len(y_test))

    lr = 0.01
    comms_round = 100

    smlp_global = SimpleMLP()
    global_model = smlp_global.build(n_features)

    # commence global training loop
    for comm_round in range(comms_round):
        # the global model's weights serve as the initial weights for all local models
        global_weights = global_model.get_weights()

        # collects each client's weights after scaling by its data share
        scaled_local_weight_list = []

        # randomize the order in which clients are visited
        client_names = list(clients_batched.keys())
        random.shuffle(client_names)

        for client in client_names:
            smlp_local = SimpleMLP()
            local_model = smlp_local.build(n_features)
            # a Keras model must be compiled before fit; each local model gets its own optimizer
            local_model.compile(
                loss='binary_crossentropy',
                optimizer=tf.keras.optimizers.SGD(learning_rate=lr),
                metrics=['accuracy'],
            )

            # start local training from the current global weights
            local_model.set_weights(global_weights)

            # fit local model with client's data
            local_model.fit(clients_batched[client], epochs=10, verbose=0)

            # scale the model weights and add to list
            scaling_factor = weight_scalling_factor(clients_batched, client)
            scaled_weights = scale_model_weights(local_model.get_weights(), scaling_factor)
            scaled_local_weight_list.append(scaled_weights)

            # clear the Keras session to free memory used by the local model
            K.clear_session()

        # the weighted average over all local models is the sum of the scaled weights
        average_weights = sum_scaled_weights(scaled_local_weight_list)

        # update global model
        global_model.set_weights(average_weights)

        # test global model and print out metrics after each communication round
        for X_batch, Y_batch in test_batched:
            global_acc = test_model(X_batch, Y_batch, global_model, comm_round)