# Paste metadata: Untitled | unknown | python | a year ago | 2.2 kB | 4 | Indexable
import torch
import torch.nn as nn
import torch.distributions as td
import torch.nn.functional as F


class DDPM(nn.Module):
    """Denoising diffusion probabilistic model (Ho et al., 2020).

    Wraps a noise-prediction ``network`` that is called as
    ``network(x_t, t)``. Here ``t`` is a float tensor of shape ``(batch, 1)``
    holding the integer timestep. The network is trained with the simplified
    noise-matching objective.

    Args:
        network: Module predicting the noise that was added to ``x_t``; its
            output is expected to have the same shape as ``x_t``.
        beta_1: First value of the linear noise schedule.
        beta_T: Last value of the linear noise schedule.
        T: Number of diffusion steps.
    """

    def __init__(self, network, beta_1=1e-4, beta_T=2e-2, T=100):
        super().__init__()
        self.network = network
        self.beta_1 = beta_1
        self.beta_T = beta_T
        self.T = T

        # Schedule tensors are frozen parameters, so they follow the module
        # through `.to(device)` and are included in the state dict.
        self.beta = nn.Parameter(torch.linspace(beta_1, beta_T, T), requires_grad=False)
        self.alpha = nn.Parameter(1 - self.beta, requires_grad=False)
        self.alpha_cumprod = nn.Parameter(self.alpha.cumprod(dim=0), requires_grad=False)
        # Base reverse-process variance beta_t. `sample` rescales it to the
        # posterior variance (1 - abar_{t-1}) / (1 - abar_t) * beta_t.
        self.sigma2 = self.beta

    def negative_elbo(self, x):
        """Return a Monte-Carlo estimate of the simplified negative ELBO.

        Draws one timestep per example and corrupts ``x`` with the
        closed-form forward process q(x_t | x_0). Returns the mean squared
        error between the injected noise and the network's prediction
        (Algorithm 1 of Ho et al., 2020).

        Args:
            x: Clean data batch of shape ``(batch, ...)``.

        Returns:
            Scalar tensor with the mean squared noise-prediction error.
        """
        batch_size = x.shape[0]
        t = torch.randint(0, self.T, (batch_size,), device=x.device)
        noise = torch.randn_like(x)

        # Per-example schedule value reshaped to (batch, 1, ..., 1), so it
        # broadcasts over the data dimensions rather than across the batch.
        alpha_bar = self.alpha_cumprod[t].view(batch_size, *([1] * (x.dim() - 1)))
        x_t = alpha_bar.sqrt() * x + (1 - alpha_bar).sqrt() * noise

        # Timestep is passed as a float tensor of shape (batch, 1).
        noise_pred = self.network(x_t, t.unsqueeze(-1).float())
        return (noise - noise_pred).pow(2).mean()

    def sample(self, shape):
        """Draw samples by running the reverse diffusion chain (Algorithm 2).

        Args:
            shape: Shape of the returned tensor, ``(n_samples, ...)``.

        Returns:
            Tensor of the given shape holding the final sample x_0.
        """
        device = self.alpha.device
        x_t = torch.randn(shape, device=device)
        n_samples = x_t.shape[0]

        for t in reversed(range(self.T)):
            alpha_t = self.alpha[t]
            alpha_bar_t = self.alpha_cumprod[t]

            # Same timestep encoding as in training: float tensor of shape (n, 1).
            t_input = torch.full((n_samples, 1), t, dtype=torch.float, device=device)
            eps = self.network(x_t, t_input)

            # Mean of p(x_{t-1} | x_t):
            #   (x_t - (1 - alpha_t) / sqrt(1 - abar_t) * eps) / sqrt(alpha_t)
            mean = (x_t - (1 - alpha_t) / (1 - alpha_bar_t).sqrt() * eps) / alpha_t.sqrt()

            if t > 0:
                # The posterior std needs abar_{t-1}, so it only exists for t > 0.
                sigma_t = (self.sigma2[t] * (1 - self.alpha_cumprod[t - 1]) / (1 - alpha_bar_t)).sqrt()
                x_t = mean + sigma_t * torch.randn_like(x_t)
            else:
                # Final step: return the mean without adding noise.
                x_t = mean
        return x_t

    def loss(self, x):
        """Return the training loss: the negative ELBO reduced to a scalar."""
        return self.negative_elbo(x).mean()
# Editor is loading...
# Leave a Comment