
# Pasted from pastecode.io (author: mail@pastecode.io), a year ago, 4.4 kB.

def synchronized_weight_matching(models, ps: PermutationSpec, method, symbols, combinations, max_iter=100):
    """Find cycle-consistent permutations that align the parameters of three models.

    Coordinate-descent weight matching over the permutations in ``ps``. On each sweep,
    every permutation is visited in random order. For each one, the pairwise similarity
    matrices between models "a", "b" and "c" are built with all other permutations held
    fixed. The three pairwise permutations for that slot are then re-solved jointly from a
    synchronization ("uber") matrix. Sweeping stops early once a full sweep brings no
    improvement in total similarity.

    :param models: mapping from model symbol to a wrapper exposing ``.model``, whose
        ``state_dict()`` is read (presumably a ``torch.nn.Module``). The symbols "a",
        "b" and "c" are indexed explicitly below and must be present.
    :param ps: PermutationSpec. ``perm_to_axes`` maps each permutation name to the
        ``(param_name, axis)`` pairs it acts on. ``axes_to_perm`` maps each parameter
        name to the permutations applied to its axes.
    :param method: solver selector, forwarded to ``optimize_synchronization``.
    :param symbols: set of model symbols (must support ``.difference``). Expected to be
        {"a", "b", "c"}.
    :param combinations: forwarded to ``parse_sync_matrix``.
    :param max_iter: maximum number of sweeps over all permutations.
    :return: nested dict ``perm_indices[source][target][perm_name]`` holding a 1-D index
        tensor used to permute ``source``'s parameters towards ``target``.
        ``perm_indices[x][y][p]`` and ``perm_indices[y][x][p]`` are inverse permutations.
    """
    import logging

    logger = logging.getLogger(__name__)

    params = {s: m.model.state_dict() for s, m in models.items()}

    # Size of each permuted axis, read from model "a" (the models are required to share
    # parameter shapes). For a 4-layer MLP this is something like
    # {'P_0': 512, 'P_1': 512, 'P_2': 512, 'P_3': 256}. Input and output dims are never permuted.
    perm_sizes = {p: params["a"][axes[0][0]].shape[axes[0][1]] for p, axes in ps.perm_to_axes.items()}

    # perm_indices[source][target][p] is the index tensor of permutation p used to permute
    # `source`'s parameters towards `target`. Every entry starts as the identity.
    # e.g. {'a': {'b': {'P_0': P_AB_0, ...}, 'c': {...}}, ...}
    # `difference({symb})`, not `difference(symb)`: a bare string would remove each of its
    # characters, which silently breaks multi-character symbols.
    perm_indices = {
        symb: {
            other_symb: {p: torch.arange(n) for p, n in perm_sizes.items()}
            for other_symb in symbols.difference({symb})
        }
        for symb in symbols
    }

    # e.g. P_0, P_1, ...
    perm_names = list(perm_sizes)

    # The three model pairs. The same list is used to score before and after each update,
    # so the improvement check compares like with like.
    pairs = (("a", "b"), ("a", "c"), ("b", "c"))

    def total_similarity(similarities, p):
        # Sum of the pairwise matching scores for permutation p under the current perm_indices.
        return sum(
            compute_weights_similarity(similarities[target][source], perm_indices[target][source][p])
            for source, target in pairs
        )

    for iteration in tqdm(range(max_iter), desc="Weight matching"):
        progress = False

        # Visit the permutations in random order.
        for p_ix in torch.randperm(len(perm_names)).tolist():
            p = perm_names[p_ix]
            n = perm_sizes[p]

            # Pairwise similarity matrices for permutation p, accumulated over every parameter
            # p acts on. Start from a Python scalar so the sums inherit the device and dtype of
            # the parameters, instead of being pinned to CPU float32 zeros.
            sim_ab = sim_ac = sim_bc = 0.0

            # All params permuted by p, together with the axis p acts on,
            # e.g. ('layer_0.weight', 0), ('layer_0.bias', 0), ...
            for params_name, axis in ps.perm_to_axes[p]:
                w_a = params["a"][params_name]
                w_b = params["b"][params_name]
                w_c = params["c"][params_name]

                assert w_a.shape == w_b.shape == w_c.shape

                perms_to_apply = ps.axes_to_perm[params_name]

                # Apply the current permutations to b/c; `except_axis` presumably leaves `axis`
                # itself unpermuted, since that is the one being solved for.
                w_b_a = get_permuted_param(w_b, perms_to_apply, perm_indices["b"]["a"], except_axis=axis)
                w_c_a = get_permuted_param(w_c, perms_to_apply, perm_indices["c"]["a"], except_axis=axis)
                w_c_b = get_permuted_param(w_c, perms_to_apply, perm_indices["c"]["b"], except_axis=axis)

                # Bring `axis` to the front and flatten the rest: one row per permuted unit.
                w_a, w_b, w_b_a, w_c_a, w_c_b = (
                    torch.moveaxis(w, axis, 0).reshape((n, -1)) for w in (w_a, w_b, w_b_a, w_c_a, w_c_b)
                )

                # Accumulate over all parameters acting on p, not just the last one.
                sim_ab = sim_ab + w_a @ w_b_a.T
                sim_ac = sim_ac + w_a @ w_c_a.T
                sim_bc = sim_bc + w_b @ w_c_b.T

            similarities = {
                "a": {"b": sim_ab, "c": sim_ac},
                "b": {"a": sim_ab.T, "c": sim_bc},
                "c": {"a": sim_ac.T, "b": sim_bc.T},
            }

            old_similarity = total_similarity(similarities, p)

            uber_matrix = three_models_uber_matrix(sim_ab, sim_ac, sim_bc, perm_dim=n)
            sync_matrix = optimize_synchronization(uber_matrix, n, method)
            sync_perm_indices = parse_sync_matrix(sync_matrix, n, symbols, combinations)

            for source, target in pairs:
                perm = perm_matrix_to_perm_indices(sync_perm_indices[(source, target)])
                perm_indices[source][target][p] = perm
                # The opposite direction needs the inverse permutation. `.T` on a 1-D index
                # tensor is a no-op and would store the same permutation both ways.
                perm_indices[target][source][p] = torch.argsort(perm)

            new_similarity = total_similarity(similarities, p)

            logger.info(
                "Iteration %s, Permutation %s: %s",
                iteration,
                p,
                (new_similarity - old_similarity).sum(dim=-1).sum(dim=-1),
            )

            progress = progress or bool(torch.any(new_similarity > old_similarity + 1e-12))

        # Stop once a full sweep improved no permutation by more than the 1e-12 tolerance.
        if not progress:
            break

    return perm_indices