Untitled
unknown
python
2 years ago
794 B
4
Indexable
def get_loss(self, batch, batch_idx):
    """
    Compute the noise-prediction training loss for one batch.

    Corresponds to Algorithm 1 from (Ho et al., 2020). For each example:
    - draw an integer timestep ``t`` uniformly from ``[0, self.t_range)``;
    - draw standard-normal noise ``eps`` with the example's shape;
    - form ``sqrt(a_hat) * x + sqrt(1 - a_hat) * eps`` with
      ``a_hat = self.alpha_bar(t)``.

    The model's prediction is then regressed onto ``eps`` with
    mean-squared error.

    NOTE(review): if ``alpha_bar(0)`` evaluates to 1, samples drawn with
    ``t == 0`` are noise-free and their target noise is unrecoverable.
    Confirm the schedule's timestep indexing against the paper's
    ``t in {1, ..., T}``.

    Args:
        batch: Tensor of clean examples. Dimension 0 indexes the examples,
            and each example receives its own timestep and noise.
        batch_idx: Not used here (presumably present to match the
            PyTorch Lightning step-hook signature).

    Returns:
        Scalar tensor: the MSE between ``self.forward``'s output and the
        sampled noise, both reshaped to ``(-1, self.in_size)``.
    """
    n = batch.shape[0]
    ts = torch.randint(0, self.t_range, [n], device=self.device)
    epsilons = torch.randn(batch.shape, device=self.device)

    # alpha_bar is an opaque per-timestep helper, so it is still called once
    # per sample (same order as before). math.sqrt keeps the original scalar
    # semantics: it raises on a_hat outside [0, 1] rather than yielding NaN.
    a_hats = [self.alpha_bar(t) for t in ts]

    # Shape (n, 1, ..., 1) so the per-sample coefficients broadcast over
    # every non-batch dimension of `batch`.
    coef_shape = (n,) + (1,) * (batch.dim() - 1)
    signal_coef = torch.tensor(
        [math.sqrt(a) for a in a_hats], device=self.device
    ).view(coef_shape)
    noise_coef = torch.tensor(
        [math.sqrt(1 - a) for a in a_hats], device=self.device
    ).view(coef_shape)

    # A single broadcast expression replaces the per-sample loop + stack.
    noise_imgs = signal_coef * batch + noise_coef * epsilons

    # The model receives the timesteps as a float column of shape (n, 1).
    e_hat = self.forward(noise_imgs, ts.unsqueeze(-1).type(torch.float))
    loss = nn.functional.mse_loss(
        e_hat.reshape(-1, self.in_size), epsilons.reshape(-1, self.in_size)
    )
    return loss
Leave a Comment