Untitled
unknown
plain_text
2 years ago
582 B
7
Indexable
class CharPredictor(nn.Module):
    """Character-level next-token predictor.

    Architecture: ``Embedding -> LSTM -> LSTM -> Linear``. Only the output of
    the final timestep is fed to the linear layer, so the model produces one
    score vector per input sequence.

    Args:
        vocab_size: Number of distinct tokens (rows of the embedding table and
            number of output scores). Defaults to ``len(chars)``, where
            ``chars`` is a module-level name read at construction time; this
            keeps ``CharPredictor()`` working exactly as before.
        emb_dim: Embedding width (default 8).
        hidden_dim: Hidden size of the first LSTM (default 128).
        lstm2_dim: Hidden size of the second LSTM, which is also the input
            width of the final linear layer (default 8).

    The submodule attribute names (``emb``, ``lstm``, ``lstm2``, ``lin``) are
    kept unchanged so previously saved ``state_dict`` checkpoints still load.
    """

    def __init__(self, vocab_size=None, *, emb_dim=8, hidden_dim=128, lstm2_dim=8):
        super().__init__()
        if vocab_size is None:
            # Backward-compatible fallback: the original always sized the
            # model from the global vocabulary ``chars``.
            vocab_size = len(chars)
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.lstm = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
        self.lstm2 = nn.LSTM(hidden_dim, lstm2_dim, batch_first=True)
        self.lin = nn.Linear(lstm2_dim, vocab_size)

    def forward(self, x):
        """Return raw scores (no softmax applied) for the token after each sequence.

        Args:
            x: Integer token indices, shaped ``(batch, seq)`` since the LSTMs
                use ``batch_first=True``. An unbatched ``(seq,)`` tensor is
                also accepted (``nn.LSTM`` supports unbatched input in
                PyTorch 1.11+).

        Returns:
            Tensor of shape ``(batch, vocab_size)``, or ``(vocab_size,)`` for
            unbatched input.
        """
        embedded = self.emb(x)
        lstm_out, _ = self.lstm(embedded)
        lstm_out2, _ = self.lstm2(lstm_out)
        # With batch_first=True the LSTM output is (batch, seq, features), or
        # (seq, features) when unbatched. The timestep axis is therefore the
        # second-to-last one in both layouts, so index it from the right.
        # (The former ``[:, -1]`` selected the last *feature* for unbatched
        # input, which then broke the Linear layer.)
        last_step = lstm_out2[..., -1, :]
        return self.lin(last_step)
Editor is loading...