Untitled
unknown
plain_text
a year ago
3.5 kB
11
Indexable
好的,我会在现有代码的基础上进行纠错、完善和补充EMD Loss的计算。以下是更新后的代码:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class SinkhornSim(nn.Module):
    """Entropy-regularised transport cost between two feature maps.

    Each spatial position of ``x`` and ``y`` is treated as one
    L2-normalised channel vector; the ground cost between two positions is
    their cosine distance ``1 - cos``.  Sinkhorn-Knopp scaling is run for a
    fixed number of iterations and the resulting plan is contracted with
    the cost matrix.

    The scaling updates are ``u = 1 / (K v)`` and ``v = 1 / (K^T u)``, so
    the plan is scaled towards unit row sums and unit column sums (total
    mass equal to the number of positions), not towards a probability
    coupling.

    Args:
        eps: Entropic regularisation strength (temperature of the Gibbs
            kernel ``exp(-cost / eps)``).
        max_iter: Number of Sinkhorn iterations; there is no early stop.
        reduction: ``'mean'`` or ``'sum'`` over the batch; any other value
            returns the per-sample costs unreduced.
    """

    def __init__(self, eps=1e-3, max_iter=100, reduction='sum'):
        super().__init__()
        self.eps = eps
        self.max_iter = max_iter
        self.reduction = reduction

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the Sinkhorn transport cost between ``x`` and ``y``.

        Args:
            x: Tensor of shape ``(batch, channels, *spatial)``.
            y: Tensor with the same batch size, channel count and number
                of spatial positions as ``x``.

        Returns:
            A scalar tensor for ``reduction`` in ``{'mean', 'sum'}``,
            otherwise a tensor of shape ``(batch,)``.
        """
        batch_size, num_channels = x.shape[:2]

        # (batch, positions, channels).  ``reshape`` rather than ``view`` so
        # that non-contiguous inputs (e.g. after a transpose) are accepted.
        x = x.reshape(batch_size, num_channels, -1).transpose(1, 2)
        y = y.reshape(batch_size, num_channels, -1).transpose(1, 2)

        # Unit-length channel vectors make the dot product a cosine.
        x = F.normalize(x, dim=-1)
        y = F.normalize(y, dim=-1)

        # Cosine distance, shape (batch, positions_x, positions_y).
        cost_matrix = 1 - torch.bmm(x, y.transpose(1, 2))

        # Work with log(K) instead of K = exp(-cost / eps): for small eps
        # the kernel underflows to exactly zero in floating point, which
        # would collapse the plan (and the returned cost) to zero.
        log_kernel = -cost_matrix / self.eps

        # Same starting point as scaling vectors filled with 1 / positions;
        # created from ``cost_matrix`` so dtype and device follow the input.
        num_x, num_y = cost_matrix.shape[1], cost_matrix.shape[2]
        log_u = (cost_matrix.new_ones(batch_size, num_x) / num_x).log()
        log_v = (cost_matrix.new_ones(batch_size, num_y) / num_x).log()

        for _ in range(self.max_iter):
            # log u = -log(K v)
            log_u = -torch.logsumexp(log_kernel + log_v.unsqueeze(1), dim=2)
            # log v = -log(K^T u)
            log_v = -torch.logsumexp(log_kernel + log_u.unsqueeze(2), dim=1)

        # plan = diag(u) K diag(v), assembled in log space before exp.
        transport_plan = torch.exp(
            log_u.unsqueeze(2) + log_kernel + log_v.unsqueeze(1)
        )
        distance = torch.sum(transport_plan * cost_matrix, dim=(1, 2))

        if self.reduction == 'mean':
            distance = distance.mean()
        elif self.reduction == 'sum':
            distance = distance.sum()
        return distance
class CosineSim(nn.Module):
    """Batch-averaged cosine similarity between two tensors.

    Every sample is flattened to a single vector, the cosine similarity is
    taken per sample, and the mean over the batch is returned.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the mean per-sample cosine similarity as a scalar tensor.

        Args:
            x: Tensor whose first dimension is the batch.
            y: Tensor that flattens to the same ``(batch, features)`` shape
                as ``x``.
        """
        # ``reshape`` instead of ``view``: ``view`` raises on inputs whose
        # memory layout is not contiguous (e.g. after permute/transpose).
        x = x.reshape(x.size(0), -1)
        y = y.reshape(y.size(0), -1)
        return F.cosine_similarity(x, y).mean()
class EMDLoss(nn.Module):
    """Loss combining a global cosine similarity with a Sinkhorn cost.

    Both terms are produced by sub-modules built with their default
    arguments: ``CosineSim`` and ``SinkhornSim``.
    """

    def __init__(self):
        super().__init__()
        self.cosine_sim = CosineSim()
        self.sinkhorn_sim = SinkhornSim()

    def forward(self, x, y):
        """Return ``2 - 2 * (1 - similarity) * transport_cost.mean()``.

        ``similarity`` is the output of ``self.cosine_sim(x, y)`` and
        ``transport_cost`` the output of ``self.sinkhorn_sim(x, y)``.
        """
        similarity = self.cosine_sim(x, y)
        transport_cost = self.sinkhorn_sim(x, y)

        # NOTE(review): with SinkhornSim's default 'sum' reduction the cost
        # is already reduced over the batch, so ``.mean()`` looks like a
        # no-op kept for shape safety - confirm the intended formula.
        dissimilarity = 1 - similarity
        return 2 - 2 * dissimilarity * transport_cost.mean()
if __name__ == "__main__":
    # Smoke run: two random feature maps of shape (1, 1024, 7, 7) are fed
    # through the loss and the resulting scalar is printed.
    feat_a = torch.randn(1, 1024, 7, 7)
    feat_b = torch.randn(1, 1024, 7, 7)
    criterion = EMDLoss()
    print('EMD Loss:', criterion(feat_a, feat_b).item())
```
### 纠错和完善说明:
1. **扁平化和转置**:在 `SinkhornSim` 中,扁平化特征图的空间维度并调整维度以便矩阵乘法。
2. **归一化**:在计算成本矩阵之前,对特征向量进行归一化处理。
3. **成本矩阵**:使用余弦距离代替欧几里得距离。
4. **Sinkhorn-Knopp 迭代**:确保在每一步迭代中避免数值不稳定问题。
5. **计算传输计划**:根据 `u` 和 `v` 计算传输计划。
6. **EMD Loss**:在 `EMDLoss` 类中,利用 `CosineSim` 和 `SinkhornSim` 计算最终的 EMD Loss。
希望这段代码能符合你的需求。如果有任何问题或需要进一步的调整,请告诉我!
Editor is loading...
Leave a Comment