Untitled
unknown
plain_text
2 years ago
820 B
13
Indexable
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps a flattened input to the mean and log-variance of a
    diagonal Gaussian over the latent space. The decoder maps a latent
    sample back to input space through a sigmoid, so reconstructions lie
    in [0, 1].

    The defaults (784 -> 256 -> 3) reproduce the original hard-coded
    architecture, i.e. inputs with 28 * 28 = 784 features per sample.

    Args:
        input_dim: Number of features per sample after flattening.
        hidden_dim: Width of the single hidden layer in encoder and decoder.
        latent_dim: Dimensionality of the latent code.
    """

    def __init__(self, input_dim=28 * 28, hidden_dim=256, latent_dim=3):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        # Layer attribute names are kept unchanged so existing state_dict
        # checkpoints still load into this module.
        self.fc1 = nn.Linear(input_dim, hidden_dim)    # encoder hidden layer
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # encoder head: mu
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # encoder head: logvar
        self.fc3 = nn.Linear(latent_dim, hidden_dim)   # decoder hidden layer
        self.fc4 = nn.Linear(hidden_dim, input_dim)    # decoder output layer

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for flattened ``x``."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps ``z`` differentiable with respect
        to ``mu`` and ``logvar``. Note that this always samples, regardless
        of whether the module is in train or eval mode.
        """
        std = torch.exp(0.5 * logvar)  # logvar = log(sigma^2)  ->  sigma
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes back to input space; outputs lie in [0, 1]."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Encode, sample and decode ``x``.

        ``x`` is flattened to ``(-1, input_dim)``. ``reshape`` is used
        instead of ``view`` so non-contiguous inputs (e.g. permuted tensors)
        are accepted; for contiguous inputs the result is identical.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x.reshape(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar