def synchronized_weight_matching(models, ps: PermutationSpec, method, symbols, combinations, max_iter=100):
"""
Find a permutation of params_b to make them match params_a.
:param ps: PermutationSpec
:param target: the parameters to match
:param to_permute: the parameters to permute
"""
    params = {s: m.model.state_dict() for s, m in models.items()}

    # For an MLP of 4 layers it would be something like {'P_0': 512, 'P_1': 512, 'P_2': 512, 'P_3': 256}. Input and output dims are never permuted.
    perm_sizes = {p: params["a"][axes[0][0]].shape[axes[0][1]] for p, axes in ps.perm_to_axes.items()}

    # {'a': {'b': {'P_0': P_AB_0, ...}, 'c': {...}}, ...}: pairwise permutation indices, initialized to the identity
    perm_indices = {
        symb: {
            other_symb: {p: torch.arange(n) for p, n in perm_sizes.items()}
            for other_symb in symbols.difference(symb)
        }
        for symb in symbols
    }

    # e.g. P_0, P_1, ...
    perm_names = list(perm_indices["a"]["b"].keys())
    for iteration in tqdm(range(max_iter), desc="Weight matching"):
        progress = False

        # iterate over the permutation matrices in random order
        for p_ix in torch.randperm(len(perm_names)):
            p = perm_names[p_ix]
            n = perm_sizes[p]

            similarities = {
                symb: {other_symb: torch.zeros((n, n)) for other_symb in symbols.difference(symb)}
                for symb in symbols
            }
            # all the params that are permuted by this permutation matrix, together with the axis on which it acts
            # e.g. ('layer_0.weight', 0), ('layer_0.bias', 0), ('layer_1.weight', 0)..
            params_and_axes = ps.perm_to_axes[p]

            for params_name, axis in params_and_axes:
                w_a = params["a"][params_name]
                w_b = params["b"][params_name]
                w_c = params["c"][params_name]

                assert w_a.shape == w_b.shape

                perms_to_apply = ps.axes_to_perm[params_name]

                # w_x_y is model x's parameter permuted towards model y on all axes except the one being optimized
                w_b_a = get_permuted_param(w_b, perms_to_apply, perm_indices["b"]["a"], except_axis=axis)
                w_c_a = get_permuted_param(w_c, perms_to_apply, perm_indices["c"]["a"], except_axis=axis)
                w_c_b = get_permuted_param(w_c, perms_to_apply, perm_indices["c"]["b"], except_axis=axis)

                # flatten all axes except the one acted on by the current permutation
                w_a = torch.moveaxis(w_a, axis, 0).reshape((n, -1))
                w_b = torch.moveaxis(w_b, axis, 0).reshape((n, -1))
                w_b_a = torch.moveaxis(w_b_a, axis, 0).reshape((n, -1))
                w_c_a = torch.moveaxis(w_c_a, axis, 0).reshape((n, -1))
                w_c_b = torch.moveaxis(w_c_b, axis, 0).reshape((n, -1))
similarities["a"]["b"] = w_a @ w_b_a.T
similarities["a"]["c"] = w_a @ w_c_a.T
similarities["b"]["c"] = w_b @ w_c_b.T
similarities["b"]["a"] = similarities["a"]["b"].T
similarities["c"]["a"] = similarities["a"]["c"].T
similarities["c"]["b"] = similarities["b"]["c"].T
            old_similarity = 0.0
            for source, target in [("a", "b"), ("a", "c"), ("c", "b")]:
                pairwise_sim = compute_weights_similarity(similarities[target][source], perm_indices[target][source][p])
                old_similarity += pairwise_sim

            # combine the pairwise similarities into a single matrix for the joint (synchronized) assignment
            uber_matrix = three_models_uber_matrix(
                similarities["a"]["b"], similarities["a"]["c"], similarities["b"]["c"], perm_dim=n
            )
            sync_matrix = optimize_synchronization(uber_matrix, n, method)

            sync_perm_indices = parse_sync_matrix(sync_matrix, n, symbols, combinations)

            for source, target in [("a", "b"), ("a", "c"), ("b", "c")]:
                perm_indices[source][target][p] = perm_matrix_to_perm_indices(sync_perm_indices[(source, target)])
                # .T only inverts a permutation *matrix*; these are index vectors, whose inverse is given by argsort
                perm_indices[target][source][p] = torch.argsort(perm_indices[source][target][p])
            new_similarity = 0.0
            for source, target in [("a", "b"), ("a", "c"), ("b", "c")]:
                pairwise_sim = compute_weights_similarity(similarities[target][source], perm_indices[target][source][p])
                new_similarity += pairwise_sim

            pylogger.info(
                f"Iteration {iteration}, Permutation {p}: {(new_similarity - old_similarity).sum(dim=-1).sum(dim=-1)}"
            )

            progress = progress or torch.any(new_similarity > old_similarity + 1e-12)
        if not progress:
            break

    return perm_indices
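

# Illustrative sketch, not part of the algorithm above: it only demonstrates the convention this
# module assumes, i.e. that permutations are stored as 1-D index tensors (see perm_indices) and
# that torch.argsort yields the inverse permutation used to fill perm_indices[target][source].
# The helper name below is hypothetical.
def _permutation_indices_demo():
    torch.manual_seed(0)
    n, d = 4, 3
    w_b = torch.randn(n, d)

    p_ab = torch.randperm(n)  # indices permuting model b's rows towards model a
    w_b_to_a = w_b[p_ab]  # row i of the permuted tensor is row p_ab[i] of w_b

    p_ba = torch.argsort(p_ab)  # inverse permutation: p_ab[p_ba] == arange(n)
    assert torch.equal(w_b_to_a[p_ba], w_b)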