Untitled

 avatar
unknown
plain_text
2 years ago
1.6 kB
3
Indexable
# ---------------------------------------------------------------------------
# One training epoch over `train_loader`, followed by a validation pass.
# NOTE(review): this fragment relies on names defined elsewhere in the file
# (model, opt, decoder, memory, comment_embedding,
# generate_square_subsequent_mask, _pad_batch_2D, eval, tar_output_gen,
# tqdm, nn, torch) — it is not runnable in isolation.
# ---------------------------------------------------------------------------
acc_loss = 0.0  # running sum of per-batch loss values
# Hoisted out of the loop: the criterion is loop-invariant.
# Padding / ignored label id is assumed to be 0 — TODO confirm against the
# tokenizer's PAD index (the commented-out pad_sequence call mentions PAD_IDX).
criterion = nn.CrossEntropyLoss(ignore_index=0)

bar = tqdm(train_loader, total=len(train_loader))
for i, ((old_node_token_ids, old_node_type_ids),
        (new_node_token_ids, new_node_type_ids),
        old_token_ids, batch_joint_graph_dgl, new_graph_dgl, dual_graph_dgl,
        changed_node_ids, batch_diff_token_ids, metadata) in enumerate(bar):

    # Teacher forcing: decoder input is the target shifted right,
    # the label is the target shifted left.
    batch_diff_token_ids = batch_diff_token_ids.long()
    tgt_input = batch_diff_token_ids[:, :-1]
    tgt_label = batch_diff_token_ids[:, 1:]

    # Re-batch the flat token stream into per-sample sequences
    # (metadata['old_token_num'] holds the per-sample lengths), then pad
    # to a rectangular (batch, max_len) LongTensor.
    old_token_ids = torch.split(old_token_ids, metadata['old_token_num'])
    pad_old_token_ids = _pad_batch_2D([t.numpy().tolist() for t in old_token_ids])
    pad_old_token_ids = torch.tensor(pad_old_token_ids).long()

    initial_tar_embeds = comment_embedding(tgt_input)
    # Causal mask: position t may not attend to positions > t.
    tgt_mask = generate_square_subsequent_mask(initial_tar_embeds.shape[1])
    # Key-padding masks assume token id 0 marks padding — TODO confirm.
    tar_embedding = decoder(initial_tar_embeds, memory,
                            tgt_mask=tgt_mask,
                            memory_key_padding_mask=pad_old_token_ids == 0,
                            tgt_key_padding_mask=tgt_input == 0)

    # BUG in original: gradients were never cleared, so they accumulated
    # across batches and every step applied the sum of all previous batch
    # gradients. Clear them before each backward/step pair.
    opt.zero_grad()
    # NOTE(review): `tar_output_gen` is produced by code not shown in this
    # fragment (presumably a vocab-logit projection of `tar_embedding`).
    # nn.CrossEntropyLoss expects logits shaped (N, C, ...) against targets
    # (N, ...) — verify the projection permutes/reshapes accordingly.
    loss = criterion(tar_output_gen, tgt_label)
    loss.backward()
    opt.step()
    acc_loss += loss.cpu().item()

# NOTE(review): `eval` here shadows the builtin — it must be a project-defined
# evaluation helper; consider renaming it (e.g. `evaluate`) at its definition.
val_loss = eval(model, valid_loader)
print('train_loss', acc_loss / len(train_loader), 'val loss', val_loss)
Editor is loading...