Untitled

 avatar
unknown
plain_text
2 years ago
582 B
7
Indexable
class CharPredictor(nn.Module):
    """Score every vocabulary character from a sequence of character indices.

    Pipeline: embedding -> LSTM -> second LSTM -> linear layer applied to the
    second LSTM's output at the final timestep (presumably to predict the
    next character).

    The defaults reproduce the original fixed architecture (embedding dim 8,
    first LSTM hidden size 128, second LSTM hidden size 8). Submodule names
    (``emb``, ``lstm``, ``lstm2``, ``lin``) are kept so existing
    ``state_dict`` checkpoints still load.

    Args:
        vocab_size: Number of distinct characters, used for both the embedding
            table and the output layer. Defaults to ``len(chars)``, where
            ``chars`` is a module-level name expected to be defined elsewhere
            in this file.
        emb_dim: Size of each character embedding.
        hidden_size: Hidden size of the first LSTM.
        out_hidden_size: Hidden size of the second LSTM, i.e. the feature
            width fed into the final linear layer.
    """

    def __init__(self, vocab_size=None, *, emb_dim=8, hidden_size=128,
                 out_hidden_size=8):
        super().__init__()
        if vocab_size is None:
            # Backward-compatible default: the original always used len(chars).
            vocab_size = len(chars)
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.lstm = nn.LSTM(emb_dim, hidden_size, batch_first=True)
        self.lstm2 = nn.LSTM(hidden_size, out_hidden_size, batch_first=True)
        self.lin = nn.Linear(out_hidden_size, vocab_size)

    def forward(self, x):
        """Return unnormalised per-character scores for each sequence.

        Args:
            x: Integer tensor of character indices. Both LSTMs use
                ``batch_first=True`` and the output is sliced with
                ``[:, -1]``, so a batched input of shape ``(batch, seq_len)``
                is expected.

        Returns:
            Tensor of shape ``(batch, vocab_size)``. No softmax is applied.
        """
        embedded = self.emb(x)           # (batch, seq_len, emb_dim)
        hidden, _ = self.lstm(embedded)  # (batch, seq_len, hidden_size)
        hidden2, _ = self.lstm2(hidden)  # (batch, seq_len, out_hidden_size)
        # With batch_first=True the time axis is dim 1, so [:, -1] selects the
        # output at the last timestep of every sequence in the batch.
        # NOTE(review): if callers right-pad sequences, this last step may be
        # padding rather than the last real character — confirm.
        return self.lin(hidden2[:, -1])


Editor is loading...