Untitled
unknown
python
a year ago
794 B
3
Indexable
def get_loss(self, batch, batch_idx):
    """Compute the noise-prediction training loss for one batch.

    Corresponds to Algorithm 1 from (Ho et al., 2020): sample one timestep
    per example, noise each example using ``self.alpha_bar`` at that
    timestep, and regress the model's noise prediction onto the true noise.

    Args:
        batch: Tensor whose first dimension indexes examples.
        batch_idx: Unused here; presumably required by the caller's
            ``(batch, batch_idx)`` hook signature.

    Returns:
        Scalar tensor: mean squared error between ``self.forward``'s output
        and the sampled noise, both reshaped to ``(-1, self.in_size)``.

    Raises:
        ValueError: if ``self.alpha_bar`` returns a value outside [0, 1]
            (its square root or that of ``1 - alpha_bar`` would be undefined).
    """
    # One integer timestep per example, drawn uniformly from [0, self.t_range).
    # NOTE(review): Algorithm 1 samples t from {1, ..., T}; whether t == 0 is a
    # meaningful (non-degenerate) step depends on alpha_bar's definition —
    # confirm against alpha_bar.
    ts = torch.randint(0, self.t_range, [batch.shape[0]], device=self.device)
    epsilons = torch.randn(batch.shape, device=self.device)

    # alpha_bar is queried once per sampled timestep (it is only known to
    # accept a single t); the resulting coefficients are then applied to the
    # whole batch at once instead of noising and stacking examples one by one.
    a_hat = torch.tensor(
        [float(self.alpha_bar(t)) for t in ts],
        dtype=torch.float64,
        device=batch.device,
    )
    if torch.any((a_hat < 0) | (a_hat > 1)):
        raise ValueError("alpha_bar must lie in [0, 1]")

    # Use the dtype that ``python_float * batch`` would produce (e.g. float32
    # for float32 or integer batches), and shape the per-example coefficients
    # as (B, 1, ..., 1) so they broadcast over every non-batch dimension.
    coef_dtype = torch.result_type(batch, 1.0)
    bcast_shape = (-1,) + (1,) * (batch.dim() - 1)
    signal_scale = torch.sqrt(a_hat).to(coef_dtype).view(bcast_shape)
    noise_scale = torch.sqrt(1 - a_hat).to(coef_dtype).view(bcast_shape)
    noise_imgs = signal_scale * batch + noise_scale * epsilons

    # Timesteps are passed to the model as a float column of shape (B, 1).
    e_hat = self.forward(noise_imgs, ts.unsqueeze(-1).type(torch.float))
    loss = nn.functional.mse_loss(
        e_hat.reshape(-1, self.in_size), epsilons.reshape(-1, self.in_size)
    )
    return loss
Editor is loading...
Leave a Comment