Untitled
unknown
python
2 years ago
2.2 kB
5
Indexable
import torch
import torch.nn as nn
import torch.distributions as td
import torch.nn.functional as F
class DDPM(nn.Module):
    """Denoising Diffusion Probabilistic Model with a linear beta schedule.

    The wrapped ``network`` is a noise predictor called as
    ``network(x_t, t)``, where ``t`` is a float tensor of shape
    ``(batch, 1)`` holding integer step indices in ``[0, T)``. It should
    return a prediction of the Gaussian noise that was mixed into ``x_t``,
    with the same shape as ``x_t``.

    Args:
        network: noise-prediction module epsilon_theta(x_t, t).
        beta_1: variance of the first forward-diffusion step.
        beta_T: variance of the last forward-diffusion step.
        T: number of diffusion steps; betas are linearly spaced in
            ``[beta_1, beta_T]``.
    """

    def __init__(self, network, beta_1=1e-4, beta_T=2e-2, T=100):
        super().__init__()
        self.network = network
        self.beta_1 = beta_1
        self.beta_T = beta_T
        self.T = T
        # The schedule is stored as frozen Parameters so it moves with
        # .to(device) and is saved in the state_dict.
        self.beta = nn.Parameter(torch.linspace(beta_1, beta_T, T), requires_grad=False)
        self.alpha = nn.Parameter(1 - self.beta, requires_grad=False)
        # alpha_cumprod[t] = prod_{s <= t} alpha[s]  (alpha-bar in the paper)
        self.alpha_cumprod = nn.Parameter(self.alpha.cumprod(dim=0), requires_grad=False)
        # sigma^2_t = beta_t for simplicity. Note this aliases the same
        # Parameter object as self.beta.
        self.sigma2 = self.beta

    @staticmethod
    def _per_sample(coeff, x):
        """Reshape a per-sample coefficient of shape (B,) to broadcast over ``x``."""
        return coeff.view(-1, *([1] * (x.dim() - 1)))

    def negative_elbo(self, x):
        """Return the simplified DDPM training objective for a batch ``x``.

        Samples one timestep per example and noises ``x`` with the closed-form
        forward process ``x_t = sqrt(abar_t) x + sqrt(1 - abar_t) eps``. It then
        returns the mean squared error between ``eps`` and the network's
        prediction. This is the "simple" loss: the negative ELBO with the
        per-step weights dropped.

        Args:
            x: clean data whose first dimension is the batch.

        Returns:
            Scalar tensor: the squared error averaged over all elements.
        """
        batch_size = x.shape[0]
        t = torch.randint(0, self.T, (batch_size,), device=x.device)
        noise = torch.randn_like(x)
        # Reshape to (B, 1, ...) so each sample is scaled by its own abar_t.
        # Multiplying a bare (B,) tensor would broadcast against the LAST axis.
        alpha_bar = self._per_sample(self.alpha_cumprod[t], x)
        x_t = alpha_bar.sqrt() * x + (1 - alpha_bar).sqrt() * noise
        noise_pred = self.network(x_t, t.unsqueeze(-1).float())
        return (noise - noise_pred).pow(2).mean()

    @torch.no_grad()
    def sample(self, shape):
        """Generate samples by running the reverse diffusion chain.

        Starts from ``x_T ~ N(0, I)`` and, for ``t = T-1 .. 0``, applies

            mean = (x_t - beta_t / sqrt(1 - abar_t) * eps_theta(x_t, t)) / sqrt(alpha_t)
            x_{t-1} = mean + sigma_t * z,   z ~ N(0, I)   (no noise at t == 0)

        with the posterior standard deviation
        ``sigma_t = sqrt(sigma2_t * (1 - abar_{t-1}) / (1 - abar_t))``.

        Args:
            shape: shape of the batch to generate; the first dimension is the
                batch size.

        Returns:
            Tensor of the requested shape on the model's device (no autograd
            graph is recorded).
        """
        device = self.alpha.device
        x_t = torch.randn(shape, device=device)
        batch_size = x_t.shape[0]
        for t in reversed(range(self.T)):
            beta_t = self.beta[t]
            alpha_t = self.alpha[t]
            alpha_bar_t = self.alpha_cumprod[t]
            # Draw the reparameterization noise first (none on the final step).
            z = torch.randn_like(x_t) if t > 0 else None
            # Same (batch, 1) float time input that negative_elbo feeds the network.
            t_input = torch.full((batch_size, 1), t, dtype=torch.float, device=device)
            eps = self.network(x_t, t_input)
            mean = (x_t - beta_t / (1 - alpha_bar_t).sqrt() * eps) / alpha_t.sqrt()
            if z is None:
                # t == 0: return the mean. Avoids indexing alpha_cumprod[-1].
                x_t = mean
            else:
                alpha_bar_prev = self.alpha_cumprod[t - 1]
                sigma_t = (self.sigma2[t] * (1 - alpha_bar_prev) / (1 - alpha_bar_t)).sqrt()
                x_t = mean + sigma_t * z
        return x_t

    def loss(self, x):
        """Return the scalar training loss for batch ``x`` (see ``negative_elbo``)."""
        return self.negative_elbo(x).mean()
Leave a Comment