Untitled
"""Entropy-regularised EMD (Sinkhorn) loss between two feature maps.

Corrected and completed version of the EMD-loss computation.

Notes on the implementation:

1. Flatten and transpose: ``SinkhornSim`` flattens the spatial dimensions of
   each feature map and moves channels last so that batched matrix products
   can be used.
2. Normalisation: feature vectors are L2-normalised before the cost matrix
   is built.
3. Cost matrix: cosine distance is used rather than Euclidean distance.
4. Sinkhorn-Knopp iterations: the updates are carried out in the log domain
   so that a small ``eps`` does not underflow ``exp(-cost / eps)`` to zero.
5. Transport plan: recovered from the two scaling vectors ``u`` and ``v``.
6. EMD loss: ``EMDLoss`` combines ``CosineSim`` and ``SinkhornSim``.
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SinkhornSim(nn.Module):
    """Sinkhorn (entropy-regularised optimal transport) distance.

    Every spatial location of ``x`` and ``y`` is treated as one point whose
    feature is the channel vector at that location. Both point sets carry
    uniform mass, and the ground cost is the cosine distance.

    Args:
        eps: Entropic regularisation strength; smaller values approach the
            unregularised EMD.
        max_iter: Number of Sinkhorn iterations (fixed, no early stopping).
        reduction: ``'mean'`` or ``'sum'`` over the batch. Any other value
            leaves the per-sample distances unreduced.
    """

    def __init__(self, eps=1e-3, max_iter=100, reduction='sum'):
        super().__init__()
        self.eps = eps
        self.max_iter = max_iter
        self.reduction = reduction

    def forward(self, x, y):
        """Return the Sinkhorn distance between ``x`` and ``y``.

        Args:
            x: Tensor of shape ``(batch, channels, *spatial)``.
            y: Tensor with the same batch and channel sizes as ``x``.

        Returns:
            A scalar tensor for ``'mean'``/``'sum'`` reduction, otherwise a
            tensor of shape ``(batch,)``.
        """
        # (B, C, *spatial) -> (B, N, C). ``flatten`` also accepts
        # non-contiguous inputs, which ``view`` would reject.
        x = F.normalize(x.flatten(2).transpose(1, 2), dim=-1)
        y = F.normalize(y.flatten(2).transpose(1, 2), dim=-1)
        batch_size, n_src = x.size(0), x.size(1)
        n_dst = y.size(1)

        # Cosine distance between every source/target location: (B, N, M).
        cost_matrix = 1 - torch.bmm(x, y.transpose(1, 2))

        # Work with log K = -C / eps. Exponentiating it directly underflows
        # to zero in float32 for small eps, which would zero the whole plan.
        log_kernel = -cost_matrix / self.eps

        # Uniform marginals, kept as logs. The scalings start at the same
        # uniform value.
        log_mu = -math.log(n_src)
        log_nu = -math.log(n_dst)
        log_u = cost_matrix.new_full((batch_size, n_src), log_mu)
        log_v = cost_matrix.new_full((batch_size, n_dst), log_nu)

        # Log-domain Sinkhorn: u = mu / (K v), v = nu / (K^T u).
        for _ in range(self.max_iter):
            log_u = log_mu - torch.logsumexp(
                log_kernel + log_v.unsqueeze(1), dim=2
            )
            log_v = log_nu - torch.logsumexp(
                log_kernel + log_u.unsqueeze(2), dim=1
            )

        # plan = diag(u) K diag(v); it has total mass one per sample.
        transport_plan = torch.exp(
            log_u.unsqueeze(2) + log_kernel + log_v.unsqueeze(1)
        )
        distance = torch.sum(transport_plan * cost_matrix, dim=(1, 2))

        if self.reduction == 'mean':
            distance = distance.mean()
        elif self.reduction == 'sum':
            distance = distance.sum()

        return distance


class CosineSim(nn.Module):
    """Batch-averaged cosine similarity of two flattened tensors."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        """Flatten each sample to a vector and return the mean similarity."""
        x = x.flatten(1)
        y = y.flatten(1)
        return F.cosine_similarity(x, y).mean()


class EMDLoss(nn.Module):
    """Loss combining a global cosine similarity and a Sinkhorn distance.

    The value returned is ``2 - 2 * (1 - m) * d`` where ``m`` is the output
    of ``CosineSim`` and ``d`` the output of ``SinkhornSim`` (default
    settings, i.e. summed over the batch).
    """

    def __init__(self):
        super().__init__()
        self.cosine_sim = CosineSim()
        self.sinkhorn_sim = SinkhornSim()

    def forward(self, x, y):
        """Return the scalar loss for feature maps ``x`` and ``y``."""
        # Global cosine similarity of the flattened feature maps.
        m = self.cosine_sim(x, y)

        # Sinkhorn distance; this is the transport cost, not the plan.
        distance = self.sinkhorn_sim(x, y)

        # NOTE(review): for identical inputs m == 1, so this evaluates to 2
        # whatever the distance is. Confirm the formula is the intended one.
        loss = 2 - 2 * (1 - m) * distance.mean()

        return loss


def main():
    """Run the loss once on random feature maps and print the result."""
    x1 = torch.randn(1, 1024, 7, 7)
    x2 = torch.randn(1, 1024, 7, 7)

    emd_loss = EMDLoss()
    loss = emd_loss(x1, x2)
    print('EMD Loss:', loss.item())


if __name__ == "__main__":
    main()
Leave a Comment