Untitled
unknown
plain_text
a year ago
820 B
5
Indexable
class VAE(nn.Module):
    """Fully connected variational autoencoder with a diagonal-Gaussian latent.

    Layout (the defaults reproduce the original hard-coded sizes)::

        encoder:  input_dim --fc1/relu--> hidden_dim --fc21--> latent_dim  (mu)
                                                     --fc22--> latent_dim  (logvar)
        decoder:  latent_dim --fc3/relu--> hidden_dim --fc4/sigmoid--> input_dim

    The layer attribute names (``fc1``, ``fc21``, ``fc22``, ``fc3``, ``fc4``)
    are deliberately unchanged so previously saved ``state_dict`` checkpoints
    still load.

    Args:
        input_dim: Number of features per flattened sample. Defaults to
            ``28 * 28`` (a 28x28 image).
        hidden_dim: Width of the single hidden layer in encoder and decoder.
        latent_dim: Dimensionality of the latent code ``z``.
    """

    def __init__(
        self,
        input_dim: int = 28 * 28,
        hidden_dim: int = 256,
        latent_dim: int = 3,
    ) -> None:
        super().__init__()
        # Remembered so forward() can flatten inputs to the right width.
        self.input_dim = input_dim
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mean head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # log-variance head
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x: torch.Tensor):
        """Encode flattened inputs into latent distribution parameters.

        Args:
            x: Tensor of shape ``(batch, input_dim)``.

        Returns:
            Tuple ``(mu, logvar)``, each of shape ``(batch, latent_dim)``.
        """
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Sampling through a fixed-noise ``eps`` keeps ``z`` differentiable with
        respect to ``mu`` and ``logvar``. Note that a fresh sample is drawn in
        both training and evaluation mode.

        Args:
            mu: Latent means.
            logvar: Latent log-variances, same shape as ``mu``.

        Returns:
            A latent sample with the same shape as ``mu``.
        """
        # logvar = log(sigma^2), so exp(0.5 * logvar) = sigma.
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent codes ``z`` into reconstructions in the range [0, 1].

        Args:
            z: Tensor of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, input_dim)`` after a sigmoid.
        """
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x: torch.Tensor):
        """Run encode -> reparameterize -> decode.

        Args:
            x: Any tensor whose element count is a multiple of ``input_dim``.
                Each group of ``input_dim`` elements is one sample, e.g. image
                batches shaped ``(batch, 1, 28, 28)`` with the defaults.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``. The reconstruction has shape
            ``(batch, input_dim)``; ``mu`` and ``logvar`` have shape
            ``(batch, latent_dim)``.
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed image
        # batches, are accepted instead of raising.
        mu, logvar = self.encode(x.reshape(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
Editor is loading...