Untitled
unknown
plain_text
3 years ago
582 B
15
Indexable
class CharPredictor(nn.Module):
    """Two-layer LSTM that scores the next character of an index sequence.

    Pipeline: embedding -> LSTM -> narrower LSTM -> linear projection onto
    the vocabulary. Only the last timestep of the second LSTM is projected,
    so one score vector is produced per input sequence.

    Args:
        vocab_size: Number of distinct characters. When ``None`` (the
            default) it falls back to ``len(chars)``, the module-level
            vocabulary this class was originally written against.
        embed_dim: Size of each character embedding.
        hidden_dim: Hidden size of the first LSTM.
        bottleneck_dim: Hidden size of the second LSTM, which is also the
            input width of the final linear layer.
    """

    def __init__(self, vocab_size=None, *, embed_dim=8, hidden_dim=128,
                 bottleneck_dim=8):
        super().__init__()
        if vocab_size is None:
            # Backward-compatible default: the original implementation
            # always sized the model from the global ``chars``.
            vocab_size = len(chars)

        # Attribute names are kept as-is so that state_dict keys of
        # previously saved checkpoints still match.
        self.emb = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.lstm2 = nn.LSTM(hidden_dim, bottleneck_dim, batch_first=True)
        self.lin = nn.Linear(bottleneck_dim, vocab_size)

    def forward(self, x):
        """Return per-character scores for the step following each sequence.

        Args:
            x: Integer tensor of character indices; expected to be batched
                as ``(batch, seq_len)`` since both LSTMs use
                ``batch_first=True``.

        Returns:
            Tensor of shape ``(batch, vocab_size)`` holding unnormalized
            scores (no softmax is applied here).
        """
        embedded = self.emb(x)
        lstm_out, _ = self.lstm(embedded)
        lstm_out2, _ = self.lstm2(lstm_out)
        # With batch_first=True the time axis is dim 1, so [:, -1] selects
        # the output of the final timestep for every sequence in the batch.
        last_step = lstm_out2[:, -1]
        return self.lin(last_step)
Editor is loading...