Untitled
unknown
python
a year ago
4.4 kB
2
Indexable
Never
def _total_pairwise_similarity(similarities, perm_indices, p, pairs):
    """Sum ``compute_weights_similarity`` over ``pairs`` for permutation ``p``.

    Used both before and after the synchronization step, so that the two totals
    are computed over exactly the same (source, target) terms and are comparable.
    """
    total = 0.0
    for source, target in pairs:
        total += compute_weights_similarity(similarities[target][source], perm_indices[target][source][p])
    return total


def synchronized_weight_matching(models, ps: PermutationSpec, method, symbols, combinations, max_iter=100):
    """Jointly match the parameters of three models ("a", "b", "c") with synchronized permutations.

    For each permutation ``p`` (visited in random order on every sweep), the function:
    1. builds the pairwise similarity matrices a-b, a-c and b-c, summed over every
       parameter that ``p`` acts on;
    2. combines them with ``three_models_uber_matrix``;
    3. solves for a synchronized solution with ``optimize_synchronization``;
    4. writes the resulting permutations back for both directions of every pair.

    Sweeps stop early when no permutation improves the total similarity.

    :param models: mapping symbol -> object exposing ``.model.state_dict()``. The symbols
        "a", "b" and "c" are accessed explicitly and must be present.
    :param ps: PermutationSpec. ``perm_to_axes[p]`` lists the ``(param_name, axis)`` pairs
        permuted by ``p``; ``axes_to_perm[param_name]`` is passed to ``get_permuted_param``.
    :param method: forwarded to ``optimize_synchronization``.
    :param symbols: set-like collection of model symbols (``.difference`` is used on it).
    :param combinations: forwarded to ``parse_sync_matrix``.
    :param max_iter: maximum number of sweeps over all permutations.
    :return: nested dict ``perm_indices[source][target][perm_name]`` of index tensors.
        ``perm_indices["b"]["a"]`` is what is applied to b's parameters when comparing
        them against a's (see the ``get_permuted_param`` calls below).
    """
    # The (source, target) terms used for the update and for both similarity totals.
    pairs = (("a", "b"), ("a", "c"), ("b", "c"))

    params = {s: m.model.state_dict() for s, m in models.items()}

    # For a MLP of 4 layers it would be something like
    # {'P_0': 512, 'P_1': 512, 'P_2': 512, 'P_3': 256}. Input and output dims are never permuted.
    perm_sizes = {p: params["a"][axes[0][0]].shape[axes[0][1]] for p, axes in ps.perm_to_axes.items()}

    # perm_indices[source][target][perm_name], e.g. {'a': {'b': {'P_0': ..., ...}, 'c': {...}}, ...}.
    # Every permutation starts as the identity.
    # `{symb}` (not `symb`): set.difference iterates its argument, so a bare string
    # would remove each of its characters.
    perm_indices = {
        symb: {
            other_symb: {p: torch.arange(n) for p, n in perm_sizes.items()}
            for other_symb in symbols.difference({symb})
        }
        for symb in symbols
    }

    # e.g. P_0, P_1, ..
    perm_names = list(perm_sizes)

    for iteration in tqdm(range(max_iter), desc="Weight matching"):
        progress = False

        # Iterate over the permutations in random order.
        for p_ix in torch.randperm(len(perm_names)):
            p = perm_names[p_ix]
            n = perm_sizes[p]

            # All params permuted by p, together with the axis on which it acts,
            # e.g. ('layer_0.weight', 0), ('layer_0.bias', 0), ('layer_1.weight', 0), ...
            params_and_axes = ps.perm_to_axes[p]

            # Accumulators on the same dtype/device as the weights they will sum.
            ref = params["a"][params_and_axes[0][0]]
            sim_ab = ref.new_zeros((n, n))
            sim_ac = ref.new_zeros((n, n))
            sim_bc = ref.new_zeros((n, n))

            for params_name, axis in params_and_axes:
                w_a = params["a"][params_name]
                w_b = params["b"][params_name]
                w_c = params["c"][params_name]

                assert w_a.shape == w_b.shape == w_c.shape

                perms_to_apply = ps.axes_to_perm[params_name]

                # except_axis=axis: presumably leaves the axis that p acts on unpermuted,
                # so its rows can be matched below -- confirm against get_permuted_param.
                w_b_a = get_permuted_param(w_b, perms_to_apply, perm_indices["b"]["a"], except_axis=axis)
                w_c_a = get_permuted_param(w_c, perms_to_apply, perm_indices["c"]["a"], except_axis=axis)
                w_c_b = get_permuted_param(w_c, perms_to_apply, perm_indices["c"]["b"], except_axis=axis)

                # Move `axis` to the front and flatten the rest -> shape (n, -1).
                w_a, w_b, w_b_a, w_c_a, w_c_b = (
                    torch.moveaxis(w, axis, 0).reshape((n, -1)) for w in (w_a, w_b, w_b_a, w_c_a, w_c_b)
                )

                # Accumulate: every parameter permuted by p must contribute, not just the last one.
                sim_ab += w_a @ w_b_a.T
                sim_ac += w_a @ w_c_a.T
                sim_bc += w_b @ w_c_b.T

            similarities = {
                "a": {"b": sim_ab, "c": sim_ac},
                "b": {"a": sim_ab.T, "c": sim_bc},
                "c": {"a": sim_ac.T, "b": sim_bc.T},
            }

            old_similarity = _total_pairwise_similarity(similarities, perm_indices, p, pairs)

            uber_matrix = three_models_uber_matrix(sim_ab, sim_ac, sim_bc, perm_dim=n)

            sync_matrix = optimize_synchronization(uber_matrix, n, method)

            sync_perm_indices = parse_sync_matrix(sync_matrix, n, symbols, combinations)

            for source, target in pairs:
                perm_matrix = sync_perm_indices[(source, target)]
                perm_indices[source][target][p] = perm_matrix_to_perm_indices(perm_matrix)
                # The reverse direction is the inverse permutation. Transpose the *matrix*
                # (P^T == P^-1); transposing the 1-D index vector would be a no-op.
                perm_indices[target][source][p] = perm_matrix_to_perm_indices(perm_matrix.T)

            new_similarity = _total_pairwise_similarity(similarities, perm_indices, p, pairs)

            pylogger.info(
                f"Iteration {iteration}, Permutation {p}: {(new_similarity - old_similarity).sum(dim=-1).sum(dim=-1)}"
            )

            progress = progress or torch.any(new_similarity > old_similarity + 1e-12)

        # A full sweep without improvement: converged.
        if not progress:
            break

    return perm_indices