Untitled
unknown
python
2 years ago
523 B
13
Indexable
def _get_encoder(M: int, *, input_dim: int = 784, hidden_dim: int = 512) -> "nn.Sequential":
    """Build the fully connected encoder network.

    The network flattens each sample (everything after the batch dimension)
    and passes it through two hidden ``Linear`` + ``ReLU`` stages before a
    final linear projection to ``2 * M`` outputs.

    Args:
        M: Latent size. The final layer emits ``2 * M`` values per sample.
            NOTE(review): presumably these are the two parameter vectors of
            a latent distribution (e.g. mean and scale); confirm with the
            caller that splits the output.
        input_dim: Number of features per sample after flattening. The
            default of 784 matches 28x28 inputs.
        hidden_dim: Width of both hidden layers.

    Returns:
        An ``nn.Sequential`` mapping ``(batch, ...)`` inputs whose trailing
        dimensions multiply to ``input_dim`` onto ``(batch, 2 * M)`` outputs.
    """
    encoder = nn.Sequential(
        # Collapse all non-batch dimensions so image-shaped input is accepted.
        nn.Flatten(),
        nn.Linear(input_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        # Two outputs per latent dimension.
        nn.Linear(hidden_dim, M * 2),
    )
    return encoder
def _get_decoder(M: int, *, hidden_dim: int = 512, output_shape=(28, 28)) -> "nn.Sequential":
    """Build the fully connected decoder network.

    The network maps a latent vector of size ``M`` through two hidden
    ``Linear`` + ``ReLU`` stages, projects to one value per output element,
    and reshapes the last dimension to ``output_shape``.

    Args:
        M: Size of the latent vector fed to the first layer.
        hidden_dim: Width of both hidden layers.
        output_shape: Shape of one decoded sample. The default ``(28, 28)``
            gives 784 output features.

    Returns:
        An ``nn.Sequential`` mapping ``(..., M)`` inputs onto
        ``(..., *output_shape)`` outputs.
    """
    output_shape = tuple(output_shape)
    # The last linear layer must emit exactly one value per output element,
    # so derive its width from the shape instead of hard-coding it.
    out_features = 1
    for extent in output_shape:
        out_features *= extent

    decoder_net = nn.Sequential(
        nn.Linear(M, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, out_features),
        # Reshape only the trailing feature dimension; leading dims are kept.
        nn.Unflatten(-1, output_shape),
    )
    return decoder_net
Leave a Comment