# Untitled
#
# mail@pastecode.io avatar
# unknown
# python
# 2 years ago
# 1.6 kB
# 16
# Indexable
# Never
# Use the custom (pre-trained) embeddings. The weights are built as float32 so
# they match the dtype of the nn.Linear layers that consume them, and
# torch.as_tensor accepts an existing tensor without the copy-construct warning
# torch.tensor() emits. freeze=False leaves the weights trainable
# (requires_grad=True).
embedding = nn.Embedding.from_pretrained(
    torch.as_tensor(embeddings, dtype=torch.float32), freeze=False
)
embedding.to(device)  # nn.Module.to moves the parameters in place

class GlobalAttention(nn.Module):
    """Attention-pooling binary classifier over one sequence of token ids.

    Every token embedding receives a scalar score from ``q``. The scores are
    softmax-normalised over the tokens and used to weight a sum of the
    ``Value`` projections. ``Linear`` followed by a sigmoid maps the pooled
    vector to a single probability of shape ``(1, 1)``.

    Args:
        embedding_layer: embedding module used to look up token ids. Defaults
            to the module-level ``embedding``.
        dim: embedding / hidden width; must equal the embedding dimension.
    """

    def __init__(self, embedding_layer=None, dim=50):
        super().__init__()
        # Register the embedding as a submodule. It is built with freeze=False,
        # i.e. meant to be fine-tuned, but a bare global is not returned by
        # model.parameters(), so an optimizer built from it skipped these
        # weights. Registering it also lets model.to(device) move it.
        self.embedding = embedding if embedding_layer is None else embedding_layer
        self.Linear = nn.Linear(dim, 1)
        self.Value = nn.Linear(dim, dim)
        self.q = nn.Linear(dim, 1)

    def forward(self, x, plot=False):
        """Return the ``(1, 1)`` probability for the token ids in ``x``.

        ``plot`` is accepted for call compatibility and is not used.
        """
        # Cast in case the embedding weights are not float32 (the Linear layers
        # are), then flatten all leading dims so that every looked-up token is
        # one attention candidate -- x may arrive as (seq,) or (1, seq).
        # NOTE(review): everything in x is pooled into ONE prediction, so x is
        # assumed to hold a single sequence.
        h = self.embedding(x).to(torch.float)
        h = h.reshape(-1, h.shape[-1])
        # One score per token. Scoring mean(h) instead produces a single logit
        # whose softmax is always 1.0, i.e. no attention at all.
        prob = F.softmax(self.q(h).flatten(), dim=0)
        t = torch.sum(prob.view(-1, 1) * self.Value(h), dim=0).view(1, -1)
        return torch.sigmoid(self.Linear(t))

model = GlobalAttention().to(device)

optim = torch.optim.Adam(params=model.parameters(), lr=3e-3)

EPOCHS = 4

# Training loop. The model pools its whole input into a single probability and
# the accuracy check compares one element, so train_loader is expected to yield
# one (x, y) sample per step.
# NOTE(review): confirm the loader's batch size / collate_fn match this.
for epoch in range(EPOCHS):
    print(f"------- EPOCH {epoch}-------")
    loss_train = 0.0
    correct = 0
    n_steps = 0
    for x, y in train_loader:
        optim.zero_grad()

        # as_tensor accepts lists as well as tensors, without the
        # copy-construct warning torch.tensor() emits for existing tensors.
        x = torch.as_tensor(x).to(device)
        y = torch.as_tensor(y).to(device).view(-1).to(torch.float)

        y_hat = model(x).view(-1)
        loss = F.binary_cross_entropy(y_hat, y)

        loss.backward()
        optim.step()

        # Threshold the predicted probability at 0.5.
        correct += 1 if torch.round(y_hat.view(1)) == y.view(1) else 0

        # .item() keeps a plain float, so the epoch summary prints a number
        # rather than a tensor repr.
        loss_train += loss.item()
        n_steps += 1

    # Guard against an empty loader (dividing by `batch + 1` raised NameError
    # because `batch` was never bound).
    n_steps = max(n_steps, 1)
    loss_train = loss_train / n_steps
    accuracy_train = correct / n_steps
    print(f"accuracy: {accuracy_train}, loss: {loss_train}")