Untitled
unknown
python
3 years ago
1.6 kB
19
Indexable
# Use the custom (pretrained) embeddings.
# freeze=False: the weights are meant to be fine-tuned. They only get trained
# because GlobalAttention registers this layer as a submodule, which puts its
# weights in model.parameters() for the optimizer.
# dtype=torch.float: build float32 weights once so the float32 nn.Linear layers
# downstream can consume them without per-call casts. torch.tensor copies, so
# fine-tuning never mutates the caller's `embeddings` array in place.
embedding = nn.Embedding.from_pretrained(
    torch.tensor(embeddings, dtype=torch.float), freeze=False
)
embedding.to(device)


class GlobalAttention(nn.Module):
    """Binary classifier that pools token embeddings with learned attention.

    Each token embedding is scored by ``self.q``. The scores are
    softmax-normalised over the token axis and used as weights for a sum of
    the projected token values (``self.Value``). The pooled vector goes
    through ``self.Linear`` and a sigmoid to produce a probability.
    """

    def __init__(self, embedding_layer=None):
        """Build the attention head.

        Args:
            embedding_layer: optional ``nn.Embedding`` to use. Defaults to the
                module-level ``embedding`` defined above.
        """
        super().__init__()
        # Stored as a submodule (not read as a global inside forward) so its
        # weights are registered, moved by model.to(device), zeroed by
        # optim.zero_grad() and updated by the optimizer.
        self.embedding = embedding if embedding_layer is None else embedding_layer
        self.embed_dim = self.embedding.embedding_dim
        # Attribute names are unchanged so state_dict keys for these layers
        # stay the same.
        self.Linear = nn.Linear(self.embed_dim, 1)
        self.Value = nn.Linear(self.embed_dim, self.embed_dim)
        self.q = nn.Linear(self.embed_dim, 1)

    def forward(self, x, plot=False):
        """Return the positive-class probability for the token ids in ``x``.

        Args:
            x: integer token-id tensor with the token axis last, e.g.
                ``(seq,)`` for one sequence or ``(batch, seq)`` for a batch.
                NOTE(review): shape assumed from the training loop below;
                confirm against train_loader.
            plot: accepted for compatibility; currently unused.

        Returns:
            Tensor of shape ``(N, 1)`` with values in (0, 1). N is 1 for a
            single 1-D sequence, otherwise the batch size.
        """
        emb = self.embedding(x).float()                # (..., seq, dim)
        # One score per token. Scoring a single mean vector instead would give
        # one score, and softmax over a single element is always 1.
        scores = self.q(emb).squeeze(-1)               # (..., seq)
        prob = F.softmax(scores, dim=-1)               # normalise over tokens
        pooled = torch.sum(prob.unsqueeze(-1) * self.Value(emb), dim=-2)  # (..., dim)
        # reshape keeps the original (1, 1) output for a single sequence.
        return torch.sigmoid(self.Linear(pooled.reshape(-1, self.embed_dim)))


model = GlobalAttention().to(device)
optim = torch.optim.Adam(params=model.parameters(), lr=3e-3)

EPOCHS = 4

for epoch in range(EPOCHS):
    print(f"------- EPOCH {epoch}-------")
    loss_sum = 0.0  # sum of per-sample losses this epoch
    correct = 0     # number of correctly classified samples this epoch
    n_seen = 0      # number of samples seen this epoch

    for x, y in train_loader:
        optim.zero_grad()
        # as_tensor avoids the copy-construct warning when x/y are already tensors.
        x = torch.as_tensor(x).to(device)
        y = torch.as_tensor(y).to(device).view(-1).to(torch.float)

        y_hat = model(x).view(-1)
        loss = F.binary_cross_entropy(y_hat, y)
        loss.backward()
        optim.step()

        n = y.numel()
        # Count per sample so batches larger than 1 are scored correctly.
        correct += (torch.round(y_hat.detach()) == y).sum().item()
        loss_sum += loss.item() * n  # BCE is a batch mean; weight it by batch size
        n_seen += n

    # max(..., 1) guards against an empty loader (previously a NameError).
    loss_train = loss_sum / max(n_seen, 1)
    accuracy_train = correct / max(n_seen, 1)
    print(f"accuracy: {accuracy_train}, loss: {loss_train}")
Editor is loading...