Untitled
unknown
plain_text
3 years ago
2.0 kB
8
Indexable
# One training epoch: encode the padded old-token sequences with a Transformer
# encoder, decode the diff-token sequences with teacher forcing, and optimize
# token-level cross-entropy (padding id 0 is ignored everywhere).
# NOTE(review): relies on `train_loader`, `valid_loader`, `encoder`, `decoder`,
# `fc_layer`, `comment_embedding`, `opt`, `device`, `count`, `acc_loss`,
# `_pad_batch_2D`, `generate_square_subsequent_mask`, `model`, and a
# project-level `eval` helper — all defined elsewhere in this file.
criterion = nn.CrossEntropyLoss(ignore_index=0)  # hoisted: build once, not per batch
bar = tqdm(train_loader, total=len(train_loader))
for i, ((old_node_token_ids, old_node_type_ids),
        (new_node_token_ids, new_node_type_ids),
        old_token_ids, batch_joint_graph_dgl, new_graph_dgl,
        batch_dual_graph_dgl, changed_node_ids, batch_diff_token_ids,
        metadata) in enumerate(bar):
    batch_diff_token_ids = batch_diff_token_ids.long()
    # Teacher forcing: input is the sequence shifted right, label shifted left.
    tgt_input = batch_diff_token_ids[:, :-1]
    tgt_label = batch_diff_token_ids[:, 1:]

    # Re-batch the flat old-token tensor into per-sample sequences, pad to a
    # rectangle, and move to the training device in one allocation.
    old_token_ids = torch.split(old_token_ids, metadata['old_token_num'])
    pad_old_token_ids = _pad_batch_2D([seq.tolist() for seq in old_token_ids])
    pad_old_token_ids = torch.tensor(pad_old_token_ids, dtype=torch.long, device=device)
    old_token_embed = comment_embedding(pad_old_token_ids)

    opt.zero_grad()
    # Padding positions (id 0) are masked in both encoder and decoder attention.
    src_pad_mask = pad_old_token_ids == 0
    memory = encoder(old_token_embed, src_key_padding_mask=src_pad_mask)
    tgt_token_embed = comment_embedding(tgt_input)
    # Causal mask so each target position only attends to earlier positions.
    tgt_mask = generate_square_subsequent_mask(tgt_token_embed.shape[1]).to(device)
    output = decoder(tgt_token_embed, memory,
                     tgt_mask=tgt_mask,
                     memory_key_padding_mask=src_pad_mask,
                     tgt_key_padding_mask=tgt_input == 0)
    pred = fc_layer(output).transpose(1, 2)  # (B, L, V) -> (B, V, L) for CrossEntropyLoss
    loss = criterion(pred, tgt_label)
    loss.backward()
    opt.step()
    count += 1
    acc_loss += loss.cpu().item()

# NOTE(review): the paste lost all indentation; validation is assumed to run
# once per epoch, after the loop — confirm against the original file.
val_loss = eval(model, valid_loader)  # project-level eval helper (shadows the builtin)
print('train_loss', acc_loss / count, 'val loss', val_loss)