# Paste-site metadata captured with the snippet (not part of the program):
# Untitled
# unknown
# python
# 4 years ago
# 1.6 kB
# 21
# Indexable
# Build the embedding layer from the custom pre-computed vectors. freeze=False
# leaves the weights trainable. Module.to() returns the module itself, so the
# device move can be chained onto the construction.
embedding = nn.Embedding.from_pretrained(
    torch.tensor(embeddings),
    freeze=False,
).to(device)
class GlobalAttention(nn.Module):
    """Binary classifier that pools one token sequence with learned attention.

    Each token embedding receives a scalar score from ``q``; the scores are
    softmax-normalised over the tokens and used to average the ``Value``
    projections into one context vector, which ``Linear`` followed by a
    sigmoid turns into a probability.
    """

    def __init__(self, embedding_layer=None):
        """Create the attention layers.

        Args:
            embedding_layer: Optional ``nn.Embedding`` used for the token
                lookup. When omitted, the module-level ``embedding`` is used,
                so ``GlobalAttention()`` keeps working as before.
        """
        super().__init__()
        # Registering the lookup table as a submodule puts its weights in
        # model.parameters(), so an optimizer built from the model actually
        # updates them, and model.to(device) moves them with the model.
        self.embedding = embedding if embedding_layer is None else embedding_layer
        dim = self.embedding.embedding_dim
        # Attribute names and creation order are kept from the original so
        # existing references to these layers stay valid.
        self.Linear = nn.Linear(dim, 1)
        self.Value = nn.Linear(dim, dim)
        self.q = nn.Linear(dim, 1)

    def forward(self, x, plot=False):
        """Return the predicted probability for a single token-id sequence.

        Args:
            x: Integer tensor of token ids for ONE sequence. It is flattened
                first, so ``(seq,)``, ``(1, seq)`` and ``(seq, 1)`` all work.
            plot: Unused; accepted only for backward compatibility.

        Returns:
            Tensor of shape ``(1, 1)`` holding the sigmoid output.
        """
        token_ids = x.reshape(-1)
        # The float32 cast is kept from the original code; presumably the
        # pretrained table may be stored in another dtype -- TODO confirm.
        embedded = self.embedding(token_ids).to(torch.float)  # (seq, dim)

        # Score every token individually. Scoring the mean embedding (as the
        # code did before) yields a single score whose softmax is always 1,
        # which disables the attention and starves ``q`` of gradient.
        scores = self.q(embedded).flatten()  # (seq,)
        weights = F.softmax(scores, dim=0)  # explicit dim: implicit is deprecated

        values = self.Value(embedded)  # (seq, dim)
        context = torch.sum(weights.view(-1, 1) * values, dim=0).view(1, -1)
        return torch.sigmoid(self.Linear(context))
# Training setup: number of passes over the data, the classifier moved to the
# target device, and an Adam optimizer over the classifier's parameters.
EPOCHS = 4
model = GlobalAttention()
model.to(device)
optim = torch.optim.Adam(model.parameters(), lr=3e-3)
# Debugging aid (commented out): inspect the shape of one batch and of its embeddings.
# x, l = next(iter(train_loader))
# print(x.shape)
# print(embed(x).shape)
# Train for EPOCHS passes over train_loader, printing the mean batch loss and
# the per-sample accuracy after every pass.
for epoch in range(EPOCHS):
    print(f"------- EPOCH {epoch}-------")
    loss_train = 0.0
    correct = 0
    seen = 0  # number of labels scored this epoch
    num_batches = 0
    for batch, (x, y) in enumerate(train_loader):
        optim.zero_grad()
        # Only wrap batches that are not tensors yet: torch.tensor(<Tensor>)
        # makes a needless copy and emits a UserWarning.
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x)
        if not isinstance(y, torch.Tensor):
            y = torch.tensor(y)
        x = x.to(device)
        y = y.to(device).view(-1).to(torch.float)

        y_hat = model(x).view(-1)
        loss = F.binary_cross_entropy(y_hat, y)
        loss.backward()
        optim.step()

        # Count matches element-wise so the tally does not depend on the
        # batch holding exactly one label.
        correct += int((torch.round(y_hat.detach()) == y).sum().item())
        seen += y.numel()
        # .item() keeps the running loss a plain float, so the summary prints
        # a number rather than a tensor repr.
        loss_train += loss.item()
        num_batches = batch + 1

    if num_batches == 0:
        # An empty loader would otherwise fail on the averages below.
        print("accuracy: n/a, loss: n/a (train_loader is empty)")
        continue

    loss_train = loss_train / num_batches
    accuracy_train = correct / seen
    print(f"accuracy: {accuracy_train}, loss: {loss_train}")