Untitled

mail@pastecode.io avatar
unknown
python
2 months ago
523 B
6
Indexable
Never
def _get_encoder(M: int):
    encoder =   nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, M*2),
        )
    return encoder

def _get_decoder(M: int):
    decoder_net = nn.Sequential(
        nn.Linear(M, 512),
        nn.ReLU(),
        nn.Linear(512, 512),
        nn.ReLU(),
        nn.Linear(512, 784),
        nn.Unflatten(-1, (28, 28))
    )
    return decoder_net
Leave a Comment