Custom LSTM

mail@pastecode.io avatar
unknown
python
2 years ago
1.9 kB
2
Indexable
Never
class CustomLSTM(nn.Module):
    """Single-layer LSTM whose gates are additionally conditioned on an encoder state.

    At every time step the stacked gate pre-activations are::

        x_t @ W  +  h_{t-1} @ U  +  e_h @ C  +  bias

    where ``e_h = encoder_hidden[0]`` stays fixed for the whole sequence. The
    ``4 * hidden`` gate columns are split, in order, into input, forget,
    candidate-cell and output gates.

    NOTE(review): every parameter has a leading ``batch_size`` dimension, so
    each batch position gets its own weight matrices, and ``forward`` only
    works when the input batch equals the ``batch_size`` given here
    (``torch.bmm`` does not broadcast batch dims). Confirm this is intended;
    a conventional LSTM shares one set of weights across the batch.
    """

    def __init__(self, input_sz, hidden_sz, batch_size):
        super().__init__()
        self.input_sz = input_sz
        self.hidden_size = hidden_sz
        gate_sz = hidden_sz * 4  # input, forget, candidate and output gates side by side
        # Uninitialised storage is fine: init_weights() overwrites it right below.
        self.W = nn.Parameter(torch.empty(batch_size, input_sz, gate_sz))
        self.U = nn.Parameter(torch.empty(batch_size, hidden_sz, gate_sz))
        self.C = nn.Parameter(torch.empty(batch_size, hidden_sz, gate_sz))
        self.bias = nn.Parameter(torch.empty(batch_size, 1, gate_sz))
        self.init_weights()

    def init_weights(self):
        """Fill every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        # In-place init must not be recorded by autograd.
        with torch.no_grad():
            for weight in self.parameters():
                weight.uniform_(-stdv, stdv)

    @staticmethod
    def _to_step_layout(state, bs):
        """Return a state tensor as (batch, 1, hidden), also accepting (1, batch, hidden)."""
        # (1, batch, hidden) is the layout forward() returns. For bs == 1 both
        # layouts have the same shape, so only an unambiguous mismatch is permuted.
        if state.dim() == 3 and bs != 1 and state.shape[:2] == (1, bs):
            return state.permute(1, 0, 2)
        return state

    def forward(self, x, init_states, encoder_hidden):
        """Run the cell over a whole input sequence.

        Args:
            x: input of shape (batch, sequence, feature); batch must equal the
                ``batch_size`` the module was built with.
            init_states: ``None`` to start from zeros, or an ``(h, c)`` pair.
                Each tensor may be (batch, 1, hidden) or (1, batch, hidden),
                the latter being the layout this method returns, so a returned
                state can be passed straight back in.
            encoder_hidden: indexable whose element 0 is batch-multiplied with
                ``C``; presumably shaped (batch, 1, hidden) -- TODO confirm
                against the caller.

        Returns:
            ``(hidden_seq, (h_n, c_n))``. ``hidden_seq`` has shape
            (sequence * batch, hidden) and holds the per-step hidden states
            concatenated time-major (all batch rows of step 0, then step 1, ...).
            ``h_n`` and ``c_n`` have shape (1, batch, hidden). For an empty
            sequence ``hidden_seq`` is (0, hidden) and the initial state is
            returned.
        """
        bs, seq_sz, _ = x.size()
        HS = self.hidden_size

        if init_states is None:
            # Allocate directly on x's device/dtype instead of CPU-then-move.
            h_t = x.new_zeros(bs, 1, HS)
            c_t = x.new_zeros(bs, 1, HS)
        else:
            h_t, c_t = init_states
            h_t = self._to_step_layout(h_t, bs)
            c_t = self._to_step_layout(c_t, bs)

        # Terms independent of t are computed once rather than on every step.
        e_h = encoder_hidden[0]
        static_gates = torch.bmm(e_h, self.C) + self.bias
        x_gates = torch.bmm(x, self.W)  # (batch, sequence, 4 * hidden)

        hidden_seq = []
        for t in range(seq_sz):
            gates = x_gates[:, t:t + 1, :] + torch.bmm(h_t, self.U) + static_gates

            i_t = torch.sigmoid(gates[:, :, :HS])          # input gate
            f_t = torch.sigmoid(gates[:, :, HS:HS * 2])    # forget gate
            g_t = torch.tanh(gates[:, :, HS * 2:HS * 3])   # candidate cell state
            o_t = torch.sigmoid(gates[:, :, HS * 3:])      # output gate

            c_t = f_t * c_t + i_t * g_t
            h_t = o_t * torch.tanh(c_t)
            hidden_seq.append(h_t)

        # torch.cat rejects an empty list, so a zero-length sequence gets an
        # explicitly empty (0, 1, hidden) tensor instead.
        if hidden_seq:
            stacked = torch.cat(hidden_seq, dim=0)
        else:
            stacked = x.new_empty(0, 1, HS)

        return stacked.squeeze(1), (h_t.permute(1, 0, 2), c_t.permute(1, 0, 2))