Untitled
unknown
python
a year ago
523 B
10
Indexable
def _get_encoder(
    M: int,
    *,
    input_shape: tuple = (28, 28),
    hidden_dim: int = 512,
):
    """Build the MLP encoder network.

    The network flattens every dimension after the batch dimension, applies
    two ``hidden_dim``-wide Linear+ReLU layers, and ends in a Linear layer
    with ``2 * M`` outputs (no final activation).

    Args:
        M: Latent dimensionality. The encoder emits ``2 * M`` features per
            example -- presumably two parameters (e.g. mean and log-scale)
            per latent dimension; confirm against the caller.
        input_shape: Per-example input shape, excluding the batch dimension.
            Defaults to ``(28, 28)`` (784 features), the original size.
        hidden_dim: Width of both hidden layers. Defaults to 512.

    Returns:
        An ``nn.Sequential`` mapping ``(batch, *input_shape)`` to
        ``(batch, 2 * M)``.

    Raises:
        ValueError: If ``M`` or ``hidden_dim`` is less than 1.
    """
    import math

    if M < 1:
        raise ValueError(f"M must be a positive integer, got {M!r}")
    if hidden_dim < 1:
        raise ValueError(f"hidden_dim must be a positive integer, got {hidden_dim!r}")

    # Number of features after flattening one example.
    in_features = math.prod(input_shape)
    encoder = nn.Sequential(
        nn.Flatten(),  # default start_dim=1: keep the batch dim, flatten the rest
        nn.Linear(in_features, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, M * 2),
    )
    return encoder


def _get_decoder(
    M: int,
    *,
    input_shape: tuple = (28, 28),
    hidden_dim: int = 512,
):
    """Build the MLP decoder network.

    The network maps an ``M``-dimensional vector through two
    ``hidden_dim``-wide Linear+ReLU layers to a flat vector, which is then
    reshaped to ``input_shape``. The final Linear layer has no activation,
    so outputs are unbounded (presumably logits; confirm against the caller).

    Args:
        M: Latent dimensionality (size of the last input dimension).
        input_shape: Shape each output example is unflattened to. Defaults to
            ``(28, 28)`` (784 features), the original size.
        hidden_dim: Width of both hidden layers. Defaults to 512.

    Returns:
        An ``nn.Sequential`` mapping ``(batch, M)`` to
        ``(batch, *input_shape)``.

    Raises:
        ValueError: If ``M`` or ``hidden_dim`` is less than 1.
    """
    import math

    if M < 1:
        raise ValueError(f"M must be a positive integer, got {M!r}")
    if hidden_dim < 1:
        raise ValueError(f"hidden_dim must be a positive integer, got {hidden_dim!r}")

    # Number of flat features needed to fill one output example.
    out_features = math.prod(input_shape)
    decoder_net = nn.Sequential(
        nn.Linear(M, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, out_features),
        # Reshape the last (feature) dimension back to the per-example shape.
        nn.Unflatten(-1, tuple(input_shape)),
    )
    return decoder_net
Editor is loading...
Leave a Comment