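Sure. Building on the existing code, I have corrected the errors, filled in the gaps, and added the EMD loss computation. Here is the updated code: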

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SinkhornSim(nn.Module):
    def __init__(self, eps=1e-3, max_iter=100, reduction='sum'):
        super(SinkhornSim, self).__init__()
        self.eps = eps
        self.max_iter = max_iter
        self.reduction = reduction

    def forward(self, x, y):
        # Flatten spatial dimensions
        batch_size, num_channels, height, width = x.size()
        x = x.view(batch_size, num_channels, -1).permute(0, 2, 1)  # Shape: (batch_size, height*width, num_channels)
        y = y.view(batch_size, num_channels, -1).permute(0, 2, 1)  # Shape: (batch_size, height*width, num_channels)

        # Normalize the vectors
        x = F.normalize(x, dim=-1)
        y = F.normalize(y, dim=-1)

        # Compute the cost matrix (cosine distance)
        cost_matrix = 1 - torch.bmm(x, y.transpose(1, 2))  # Shape: (batch_size, height*width, height*width)

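        # Sinkhorn-Knopp initialization in the log domain: exp(-cost / eps)
        # underflows to zero in float32 when eps is small
        num_points = height * width
        log_K = -cost_matrix / self.eps
        # Uniform marginals: every spatial location carries mass 1 / num_points
        log_marginal = cost_matrix.new_full((batch_size, num_points), 1.0 / num_points).log()
        log_u = torch.zeros_like(log_marginal)
        log_v = torch.zeros_like(log_marginal)

        # Sinkhorn iterations: u = a / (K v) and v = b / (K^T u), kept as logs
        for _ in range(self.max_iter):
            log_u = log_marginal - torch.logsumexp(log_K + log_v.unsqueeze(-2), dim=-1)
            log_v = log_marginal - torch.logsumexp(log_K + log_u.unsqueeze(-1), dim=-2)

        # Compute the optimal transport plan: diag(u) K diag(v)
        transport_plan = torch.exp(log_u.unsqueeze(-1) + log_K + log_v.unsqueeze(-2))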
        distance = torch.sum(transport_plan * cost_matrix, dim=(1, 2))

        if self.reduction == 'mean':
            distance = distance.mean()
        elif self.reduction == 'sum':
            distance = distance.sum()

        return distance

class CosineSim(nn.Module):
    def __init__(self):
        super(CosineSim, self).__init__()

    def forward(self, x, y):
        x = x.view(x.size(0), -1)  # Flatten
        y = y.view(y.size(0), -1)  # Flatten
        return F.cosine_similarity(x, y).mean()

class EMDLoss(nn.Module):
    def __init__(self):
        super(EMDLoss, self).__init__()
        self.cosine_sim = CosineSim()
        self.sinkhorn_sim = SinkhornSim()

    def forward(self, x, y):
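        # Mean cosine similarity between the flattened feature maps (a scalar)
        m = self.cosine_sim(x, y)

        # Sinkhorn transport cost between the two sets of local features
        # (a scalar, because SinkhornSim reduces over the batch by default)
        pi = self.sinkhorn_sim(x, y)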

        # Compute the EMD loss
        loss = 2 - 2 * (1 - m) * pi.mean()
        return loss

if __name__ == "__main__":
    x1 = torch.randn(1, 1024, 7, 7)
    x2 = torch.randn(1, 1024, 7, 7)
    emd_loss = EMDLoss()
    loss = emd_loss(x1, x2)
    print('EMD Loss:', loss.item())
```
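A quick way to sanity-check a Sinkhorn step like the one in `SinkhornSim` is to confirm that the rows and columns of the transport plan sum to the intended marginals. The following is a minimal, dependency-free sketch, not part of the code above: `sinkhorn_plan` is a hypothetical helper, it assumes uniform marginals, and the `eps` and `max_iter` defaults are illustrative choices. It works in the log domain so that small `eps` values do not underflow.

```python
import math


def sinkhorn_plan(cost, eps=0.1, max_iter=200):
    """Entropic transport plan between two uniform distributions (hypothetical helper)."""
    n, m = len(cost), len(cost[0])
    log_a, log_b = -math.log(n), -math.log(m)  # uniform marginals (assumption)
    log_k = [[-c / eps for c in row] for row in cost]
    log_u, log_v = [0.0] * n, [0.0] * m

    def logsumexp(values):
        top = max(values)
        return top + math.log(sum(math.exp(v - top) for v in values))

    for _ in range(max_iter):
        # u = a / (K v), written in logs
        log_u = [log_a - logsumexp([log_k[i][j] + log_v[j] for j in range(m)])
                 for i in range(n)]
        # v = b / (K^T u), written in logs
        log_v = [log_b - logsumexp([log_k[i][j] + log_u[i] for i in range(n)])
                 for j in range(m)]

    # Transport plan: diag(u) K diag(v)
    return [[math.exp(log_u[i] + log_k[i][j] + log_v[j]) for j in range(m)]
            for i in range(n)]


cost = [[0.0, 1.0, 2.0],
        [1.0, 0.0, 1.0],
        [2.0, 1.0, 0.0]]
plan = sinkhorn_plan(cost)
row_sums = [sum(row) for row in plan]
col_sums = [sum(col) for col in zip(*plan)]
```

If the row or column sums drift away from `1 / n`, the iteration is either not converged or is scaling toward different marginals than intended.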

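### Notes on the corrections and improvements:
1. **Flattening and transposing**: In `SinkhornSim`, the spatial dimensions of the feature maps are flattened and the axes are permuted so the tensors are ready for batched matrix multiplication.
2. **Normalization**: The feature vectors are L2-normalized before the cost matrix is computed.
3. **Cost matrix**: Cosine distance is used instead of Euclidean distance.
4. **Sinkhorn-Knopp iterations**: Each iteration step is guarded against numerical instability.
5. **Transport plan**: The transport plan is computed from the Sinkhorn scaling vectors `u` and `v`.
6. **EMD Loss**: The `EMDLoss` class combines `CosineSim` and `SinkhornSim` to compute the final EMD loss.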
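
For comparison, one alternative way to combine a transport plan with cosine similarity is to weight the full similarity map by the plan, giving `loss = 2 - 2 * sum(plan * similarity)`. The sketch below illustrates that variant under stated assumptions. It is not the formula used in `EMDLoss` above: `transport_weighted_loss` is a hypothetical helper, the marginals are assumed uniform, and `eps=0.05` is an illustrative default.

```python
import torch
import torch.nn.functional as F


def transport_weighted_loss(x, y, eps=0.05, max_iter=100):
    """Hypothetical helper: 2 - 2 * sum(plan * similarity), averaged over the batch."""
    # (B, C, H, W) -> (B, H*W, C), then L2-normalize each local feature
    x = F.normalize(x.flatten(2).transpose(1, 2), dim=-1)
    y = F.normalize(y.flatten(2).transpose(1, 2), dim=-1)

    similarity = torch.bmm(x, y.transpose(1, 2))  # cosine similarity map
    log_K = -(1 - similarity) / eps               # cost is the cosine distance
    batch_size, n = similarity.size(0), similarity.size(1)
    log_a = similarity.new_full((batch_size, n), 1.0 / n).log()  # uniform marginals (assumption)

    # Log-domain Sinkhorn iterations
    log_u = torch.zeros_like(log_a)
    log_v = torch.zeros_like(log_a)
    for _ in range(max_iter):
        log_u = log_a - torch.logsumexp(log_K + log_v.unsqueeze(-2), dim=-1)
        log_v = log_a - torch.logsumexp(log_K + log_u.unsqueeze(-1), dim=-2)

    plan = torch.exp(log_u.unsqueeze(-1) + log_K + log_v.unsqueeze(-2))
    # The plan's mass sums to 1 and similarity lies in [-1, 1], so the loss lies in [0, 4]
    return (2 - 2 * (plan * similarity).sum(dim=(1, 2))).mean()


torch.manual_seed(0)
x1 = torch.randn(2, 64, 3, 3)
x2 = torch.randn(2, 64, 3, 3)
loss_same = transport_weighted_loss(x1, x1)
loss_diff = transport_weighted_loss(x1, x2)
```

With identical inputs the plan concentrates on the diagonal, where the similarity is 1, so the loss should be close to 0. Unrelated inputs should give a clearly larger value.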

I hope this code meets your needs. If you have any questions or need further adjustments, let me know!