Untitled

mail@pastecode.io avatar
unknown
python
2 years ago
5.5 kB
1
Indexable
Never
def _process_batch(self, inp_x, inp_sim, out_sim, inp_y, out_y, sim_len, sent_len, sent_num, y_lengths):
    """
    处理每个 batch 数据
    :param inp_x: 输入的对话数据
    :param inp_sim: k 个和 n-th utterance 相似的其他 utterance
    :param out_sim: n-th utterance 分类的 ground truth 标签
    :param inp_y: 输入的 n-th utterance
    :param out_y: 期望输出的 n-th utterance (ground truth)
    :param sim_len: 每个相似 utterance 的实际句长
    :param sent_len: 对话中每个句子的实际长度
    :param sent_num: 对话中 utterance 的实际数量
    :param y_lengths: n-th utterance 的实际长度
    :return: losses, batch_outputs,一个 loss list 和该 batch 对应的输出,batch_outputs 是一个字典类型的数据
    """
    self.model.train()

    # 获得模型输出
    outputs = self.model(self.config["model"]["k"], inp_x, inp_sim, inp_y, sim_len, sent_len, sent_num, y_lengths)
    sent_prob, outs_enc, outs_enc_filter, dec1, dec2, sent_len, dialogue_pre, summary_pre = outputs
    batch_outputs = {"model_outputs": outputs}

    # --------------------------------------------------------------
    # 1 - 预测 n-th utterance 损失计算,基于全部 dialogue 生成
    # --------------------------------------------------------------
    _dec1_logits = dec1[0].contiguous().view(-1, dec1[0].size(-1))
    _x_labels = out_y.contiguous().view(-1)
    nsent_loss = F.cross_entropy(_dec1_logits, _x_labels, ignore_index=0, reduction='none')

    nsent_loss_token = nsent_loss.view(out_y.size())
    batch_outputs["n_sent"] = nsent_loss_token
    mean_rec_loss = nsent_loss.sum() / y_lengths.float().sum()
    losses = [mean_rec_loss]

    # --------------------------------------------------------------
    # 1.5 - 从摘要预测 n-th utterance 损失计算
    # --------------------------------------------------------------
    if self.config["model"]["n_sent_sum_loss"]:
        _dec2_logits = dec2[0].contiguous().view(-1, dec2[0].size(-1))
        nsent_loss_sum = F.cross_entropy(_dec2_logits, _x_labels, ignore_index=0, reduction='none')
        nsent_loss_token_sum = nsent_loss_sum.view(out_y.size())
        batch_outputs["n_sent_sum"] = nsent_loss_token_sum
        mean_rec_sum_loss = nsent_loss_sum.sum() / y_lengths.float().sum()
        losses.append(mean_rec_sum_loss)
    else:
        mean_rec_sum_loss = None

    # --------------------------------------------------------------
    # 2 - KL 散度计算,让两个解码器分布相似
    # --------------------------------------------------------------
    if self.config["model"]["doc_sum_kl_loss"]:
        _dec1_logits = dec1[0].contiguous().view(-1, dec1[0].size(-1))
        _dec2_logits = dec2[0].contiguous().view(-1, dec2[0].size(-1))
        #kl_loss = torch.nn.functional.kl_div(_dec2_logits, _dec2_logits, size_average=None, reduce=True, reduction='mean')
        kl_loss = kl_categorical(_dec1_logits, _dec2_logits)
        losses.append(kl_loss)
    else:
        kl_loss = None

    # --------------------------------------------------------------
    # 3 - 长度损失
    # --------------------------------------------------------------
    if self.config["model"]["length_loss"]:
        _, topk_indices = torch.topk(sent_prob, k=self.config["model"]["k"], dim=1)
        topk_indices = torch.squeeze(topk_indices, -1)
        sum_length = torch.gather(sent_len, dim=1, index=topk_indices)
        sum_length = torch.sum(sum_length, dim=1)
        tmp = torch.sub(self.config["data"]["ext_sum_len"], sum_length)
        length_loss = torch.mean(tmp.float())
        losses.append(length_loss)
    else:
        length_loss = None

    # --------------------------------------------------------------
    # 4 - 文档摘要的 representation 和文档的 representation 尽量相近
    # --------------------------------------------------------------
    if self.config["model"]["doc_sum_sim_loss"]:
        dialog_rep = torch.squeeze(torch.sum(outs_enc, 1))
        summary_rep = torch.squeeze(torch.sum(outs_enc_filter, 1))
        sim_loss = pairwise_loss(dialog_rep, summary_rep)
        losses.append(sim_loss)
    else:
        sim_loss = None

    # --------------------------------------------------------------
    # 4 - n-th utterance 分类任务
    # --------------------------------------------------------------
    if self.config["model"]["nsent_classification"]:
        criterion = torch.nn.BCEWithLogitsLoss()
        dia_pre_loss = criterion(dialogue_pre, out_sim.float())
        sum_pre_loss = criterion(summary_pre, out_sim.float())
        pre_kl_loss = kl_categorical(dialogue_pre, summary_pre)
        losses.append(dia_pre_loss)
        losses.append(sum_pre_loss)
        losses.append(pre_kl_loss)
    else:
        dia_pre_loss = None
        sum_pre_loss = None
        pre_kl_loss = None

    prior_loss = None
    topic_loss = None
    kl_loss = None
    # --------------------------------------------------------------
    # Plot Norms of loss gradient wrt to the compressor
    # --------------------------------------------------------------
    if self.config["plot_norms"] and self.step % self.config["log_interval"] == 0:
        batch_outputs["grad_norm"] = self._debug_grad_norms(
            mean_rec_loss,
            prior_loss,
            topic_loss,
            kl_loss)

    return losses, batch_outputs