Untitled

mail@pastecode.io avatar
unknown
python
a month ago
2.2 kB
2
Indexable
Never
import torch
import torch.nn as nn
import torch.distributions as td
import torch.nn.functional as F

class DDPM(nn.Module):
    """Denoising diffusion probabilistic model with a linear beta schedule.

    The wrapped ``network`` is trained to predict the Gaussian noise that was
    mixed into a noisy input. It is always called as ``network(x_t, t)``, where
    ``t`` is a float tensor of shape ``(batch_size, 1)`` holding the integer
    diffusion step index in ``[0, T)`` for each example.

    Args:
        network: noise-prediction model, called as ``network(x_t, t)``.
        beta_1: variance of the first forward-diffusion step.
        beta_T: variance of the last forward-diffusion step.
        T: number of diffusion steps.
    """

    def __init__(self, network, beta_1=1e-4, beta_T=2e-2, T=100):
        super().__init__()
        self.network = network
        self.beta_1 = beta_1
        self.beta_T = beta_T
        self.T = T

        # The schedules are stored as frozen nn.Parameters: they move with
        # .to(device) and appear in state_dict under these names.
        self.beta = nn.Parameter(torch.linspace(beta_1, beta_T, T), requires_grad=False)
        self.alpha = nn.Parameter(1 - self.beta, requires_grad=False)
        self.alpha_cumprod = nn.Parameter(self.alpha.cumprod(dim=0), requires_grad=False)
        # sigma2 aliases beta. sample() rescales it by
        # (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t), so the reverse-step variance
        # actually used is the posterior variance beta_tilde_t.
        self.sigma2 = self.beta

    def negative_elbo(self, x):
        """Return the simplified DDPM training objective, averaged over the batch.

        This is the unweighted noise-prediction MSE ("L_simple"), not the exact
        negative ELBO. For each example, one step ``t`` is drawn uniformly from
        ``[0, T)`` and ``x`` is diffused to
        ``x_t = sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * eps``.
        The return value is the mean squared error between ``eps`` and the
        network's prediction.

        Args:
            x: batch of data whose first dimension is the batch dimension; any
                trailing shape is accepted.

        Returns:
            A scalar tensor.
        """
        batch_size = x.shape[0]
        t = torch.randint(0, self.T, (batch_size,), device=x.device)
        noise = torch.randn_like(x)

        # Reshape the per-example coefficient from (batch,) to (batch, 1, ..., 1)
        # so it broadcasts over the feature dimensions, not across the batch.
        alpha_cumprod_t = self.alpha_cumprod[t].view((batch_size,) + (1,) * (x.dim() - 1))
        x_t = alpha_cumprod_t.sqrt() * x + (1 - alpha_cumprod_t).sqrt() * noise

        noise_pred = self.network(x_t, t.unsqueeze(-1).float())
        return (noise - noise_pred).pow(2).mean()

    @torch.no_grad()
    def sample(self, shape):
        """Draw samples by running the reverse diffusion chain from pure noise.

        Runs without autograd, so the T network calls do not build a graph.

        Args:
            shape: shape of the sample tensor; ``shape[0]`` is the number of
                samples.

        Returns:
            A tensor of the given shape on the same device as the model's
            schedule parameters.
        """
        device = self.alpha.device
        x_t = torch.randn(shape, device=device)
        n_samples = x_t.shape[0]

        for t in reversed(range(self.T)):
            alpha_t = self.alpha[t]
            alpha_cumprod_t = self.alpha_cumprod[t]

            # Give the network the same (batch, 1) float time input it sees in training.
            t_input = torch.full((n_samples, 1), float(t), device=device)
            eps_pred = self.network(x_t, t_input)

            # Reverse-process mean:
            # (x_t - (1 - alpha_t) / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
            x_prev_mean = (x_t - (1 - alpha_t) / (1 - alpha_cumprod_t).sqrt() * eps_pred) / alpha_t.sqrt()

            if t > 0:
                # Posterior variance beta_tilde_t. It is only needed for t > 0, which
                # also avoids indexing alpha_cumprod[-1] at the final step.
                var_t = self.sigma2[t] * (1 - self.alpha_cumprod[t - 1]) / (1 - alpha_cumprod_t)
                x_t = x_prev_mean + var_t.sqrt() * torch.randn_like(x_t)
            else:
                # The final step returns the mean with no added noise.
                x_t = x_prev_mean

        return x_t

    def loss(self, x):
        """Return the scalar training loss for a batch ``x``.

        ``negative_elbo`` already averages over the batch, so the ``.mean()``
        here is a no-op kept for interface compatibility.
        """
        return self.negative_elbo(x).mean()
Leave a Comment