Untitled

 avatar
unknown
plain_text
a year ago
2.5 kB
5
Indexable
def compute_emd(x, y):
    """Return the minimum-cost one-to-one matching cost between ``x`` and ``y``.

    The entries of ``x`` and ``y`` are treated as positions on a line: the cost
    of pairing ``x[i]`` with ``y[j]`` is ``abs(x[i] - y[j])``, and the result is
    the total cost of the cheapest assignment of entries of one vector to
    distinct entries of the other.

    Args:
        x: 1-D array-like of numbers.
        y: 1-D array-like of numbers.

    Returns:
        A numpy scalar holding the summed cost of the optimal assignment.

    Raises:
        Whatever ``linear_sum_assignment`` raises for inputs that take the
        general path (for example a cost matrix containing NaN).
    """
    x = np.array(x)
    y = np.array(y)

    # Fast path.  For equally long 1-D inputs and an absolute-difference cost,
    # pairing the i-th smallest entry of x with the i-th smallest entry of y is
    # an optimal assignment, so two sorts replace the roughly cubic assignment
    # solver.  The guard keeps every input whose handling could differ
    # (unsigned/complex/object dtypes, NaN or infinite values, other shapes) on
    # the general path below.  For floats the value may differ from the solver
    # result in the last bits because the terms are summed in another order.
    sortable = (
        x.ndim == 1
        and x.shape == y.shape
        and x.dtype.kind in "if"
        and y.dtype.kind in "if"
        and np.isfinite(x).all()
        and np.isfinite(y).all()
    )
    if sortable:
        return np.abs(np.sort(x) - np.sort(y)).sum()

    # General path: explicit pairwise cost matrix plus assignment solver.  With
    # inputs of different lengths only min(len(x), len(y)) pairs are formed.
    cost_matrix = np.abs(x[:, np.newaxis] - y[np.newaxis, :])
    row_ind, col_ind = linear_sum_assignment(cost_matrix)
    return cost_matrix[row_ind, col_ind].sum()

def run_kmeans_emd(x, args):
    """Run faiss k-means on ``x`` and reassign every sample by EMD distance.

    One clustering is trained per entry of ``args.num_cluster``, using an L2
    flat index on GPU ``args.gpu``; the entry's position in that sequence is
    used as the faiss seed.  The centroids are the ones faiss produces; only
    the sample-to-cluster assignment is recomputed, by calling ``compute_emd``
    between every row of ``x`` and every centroid.

    Args:
        x: 2-D array, one sample per row.  Presumably a float32 numpy array as
            faiss expects -- TODO confirm against the caller.
        args: Object exposing ``num_cluster`` (iterable of cluster counts),
            ``gpu`` (device id for the faiss index) and ``temperature``
            (scale applied to the normalised density).

    Returns:
        dict with keys ``'im2cluster'``, ``'centroids'`` and ``'density'``.
        Each value is a list holding one CUDA tensor per requested cluster
        count: the cluster index of every sample, the L2-normalised centroids,
        and the per-cluster density scaled by ``args.temperature``.
    """
    print('performing kmeans clustering with EMD distance')
    results = {'im2cluster': [], 'centroids': [], 'density': []}

    for seed, num_cluster in enumerate(args.num_cluster):
        dim = x.shape[1]
        k = int(num_cluster)

        # faiss k-means configuration.
        kmeans = faiss.Clustering(dim, k)
        kmeans.verbose = True
        kmeans.niter = 20
        kmeans.nredo = 5
        kmeans.seed = seed
        kmeans.max_points_per_centroid = 1000
        kmeans.min_points_per_centroid = 10

        # Full-precision flat L2 index on the requested GPU.
        gpu_resources = faiss.StandardGpuResources()
        index_cfg = faiss.GpuIndexFlatConfig()
        index_cfg.useFloat16 = False
        index_cfg.device = args.gpu
        l2_index = faiss.GpuIndexFlatL2(gpu_resources, dim, index_cfg)

        kmeans.train(x, l2_index)

        # Compute EMD distances by hand and reassign each sample to its
        # nearest centroid.
        centroids = faiss.vector_to_array(kmeans.centroids).reshape(k, dim)
        emd = np.zeros((x.shape[0], k))
        for row, sample in enumerate(x):
            for col, centroid in enumerate(centroids):
                emd[row, col] = compute_emd(sample, centroid)

        im2cluster = np.argmin(emd, axis=1).tolist()

        # Distance of every sample to its own centroid, grouped by cluster.
        member_dists = [[] for _ in range(k)]
        for sample_id, cluster_id in enumerate(im2cluster):
            member_dists[cluster_id].append(emd[sample_id, cluster_id])

        # Density estimate for clusters with more than one member; the others
        # are collected and filled in afterwards.
        # NOTE(review): the square root is applied to EMD values here --
        # confirm that this is intended.
        density = np.zeros(k)
        sparse_clusters = []
        for cluster_id, dists in enumerate(member_dists):
            if len(dists) <= 1:
                sparse_clusters.append(cluster_id)
                continue
            mean_root = (np.asarray(dists) ** 0.5).mean()
            density[cluster_id] = mean_root / np.log(len(dists) + 10)

        # Clusters with at most one member receive the largest density found.
        largest = density.max()
        for cluster_id in sparse_clusters:
            density[cluster_id] = largest

        # Clamp to the 10th-90th percentile range, then rescale so that the
        # mean equals args.temperature.
        lower = np.percentile(density, 10)
        upper = np.percentile(density, 90)
        density = np.clip(density, lower, upper)
        density = args.temperature * density / density.mean()

        centroids = F.normalize(torch.Tensor(centroids).cuda(), p=2, dim=1)
        im2cluster = torch.LongTensor(im2cluster).cuda()
        density = torch.Tensor(density).cuda()

        results['centroids'].append(centroids)
        results['density'].append(density)
        results['im2cluster'].append(im2cluster)

    return results
Editor is loading...
Leave a Comment