Untitled

mail@pastecode.io avatar
unknown
plain_text
6 months ago
820 B
3
Indexable
Never
class VAE(nn.Module):
    """Variational autoencoder built from fully connected layers.

    The encoder is one hidden layer (``fc1``) followed by two parallel linear
    heads (``fc21`` for ``mu``, ``fc22`` for ``logvar``). The decoder is one
    hidden layer (``fc3``) followed by an output layer (``fc4``) passed
    through a sigmoid.

    Args:
        input_dim: Number of features per flattened sample (default 28*28).
        hidden_dim: Width of the hidden layer in both encoder and decoder.
        latent_dim: Size of the latent vector ``z``.

    The layer attribute names are kept as ``fc1``/``fc21``/``fc22``/``fc3``/
    ``fc4`` so that state dicts saved with the default sizes still load.
    """

    def __init__(self, input_dim=28 * 28, hidden_dim=256, latent_dim=3):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim

        # Layers are created in the same order as before so that seeded
        # parameter initialisation is unchanged for the default sizes.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # head producing mu
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # head producing logvar
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x: torch.Tensor):
        """Map flattened inputs to the pair ``(mu, logvar)``.

        Args:
            x: Tensor whose last dimension is ``input_dim``.

        Returns:
            Tuple ``(mu, logvar)``, each with last dimension ``latent_dim``.
        """
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Return ``mu + eps * std`` with ``eps`` drawn from a standard normal.

        ``std`` is computed as ``exp(0.5 * logvar)``. Writing the sample this
        way keeps the result differentiable with respect to ``mu`` and
        ``logvar``, because the randomness lives entirely in ``eps``.
        """
        std = torch.exp(0.5 * logvar)  # standard deviation
        eps = torch.randn_like(std)  # noise with the same shape/dtype/device as std
        return mu + eps * std

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Map latent vectors to reconstructions.

        Args:
            z: Tensor whose last dimension is ``latent_dim``.

        Returns:
            Tensor with last dimension ``input_dim``; the sigmoid bounds every
            value to the range [0, 1].
        """
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x: torch.Tensor):
        """Encode ``x``, sample a latent vector, and decode it.

        Args:
            x: Tensor of any shape whose element count is a multiple of
                ``input_dim``; it is flattened to ``(-1, input_dim)``.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        # reshape (rather than view) also accepts non-contiguous inputs and
        # gives the same result as view for contiguous ones.
        flat = x.reshape(-1, self.input_dim)
        mu, logvar = self.encode(flat)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar