# Module-level imports assumed by this method (the paste begins mid-class).
import torch
import torch.nn.functional as F

def _process_batch(self, inp_x, inp_sim, out_sim, inp_y, out_y,
                   sim_len, sent_len, sent_num, y_lengths):
    """
    Process one batch of data.

    :param inp_x: input dialogue data
    :param inp_sim: the k utterances most similar to the n-th utterance
    :param out_sim: ground-truth labels for the n-th-utterance classification
    :param inp_y: the input n-th utterance
    :param out_y: the expected (ground-truth) n-th utterance
    :param sim_len: actual length of each similar utterance
    :param sent_len: actual length of each sentence in the dialogue
    :param sent_num: actual number of utterances in the dialogue
    :param y_lengths: actual length of the n-th utterance
    :return: (losses, batch_outputs) — a list of losses and the outputs for
             this batch; batch_outputs is a dict
    """
    self.model.train()

    # Run the model; note that sent_len is re-bound to the value the model returns.
    outputs = self.model(self.config["model"]["k"], inp_x, inp_sim, inp_y,
                         sim_len, sent_len, sent_num, y_lengths)
    (sent_prob, outs_enc, outs_enc_filter, dec1, dec2,
     sent_len, dialogue_pre, summary_pre) = outputs
    batch_outputs = {"model_outputs": outputs}

    # --------------------------------------------------------------
    # 1 - Reconstruction loss for the n-th utterance, generated from
    #     the full dialogue
    # --------------------------------------------------------------
    _dec1_logits = dec1[0].contiguous().view(-1, dec1[0].size(-1))
    _x_labels = out_y.contiguous().view(-1)
    nsent_loss = F.cross_entropy(_dec1_logits, _x_labels,
                                 ignore_index=0, reduction='none')
    nsent_loss_token = nsent_loss.view(out_y.size())
    batch_outputs["n_sent"] = nsent_loss_token
    mean_rec_loss = nsent_loss.sum() / y_lengths.float().sum()
    losses = [mean_rec_loss]

    # --------------------------------------------------------------
    # 1.5 - Reconstruction loss for the n-th utterance, generated from
    #       the summary
    # --------------------------------------------------------------
    if self.config["model"]["n_sent_sum_loss"]:
        _dec2_logits = dec2[0].contiguous().view(-1, dec2[0].size(-1))
        nsent_loss_sum = F.cross_entropy(_dec2_logits, _x_labels,
                                         ignore_index=0, reduction='none')
        nsent_loss_token_sum = nsent_loss_sum.view(out_y.size())
        batch_outputs["n_sent_sum"] = nsent_loss_token_sum
        mean_rec_sum_loss = nsent_loss_sum.sum() / y_lengths.float().sum()
        losses.append(mean_rec_sum_loss)
    else:
        mean_rec_sum_loss = None

    # --------------------------------------------------------------
    # 2 - KL divergence, pushing the two decoder distributions together
    # --------------------------------------------------------------
    if self.config["model"]["doc_sum_kl_loss"]:
        _dec1_logits = dec1[0].contiguous().view(-1, dec1[0].size(-1))
        _dec2_logits = dec2[0].contiguous().view(-1, dec2[0].size(-1))
        kl_loss = kl_categorical(_dec1_logits, _dec2_logits)
        losses.append(kl_loss)
    else:
        kl_loss = None

    # --------------------------------------------------------------
    # 3 - Length loss: penalize summaries whose selected sentences
    #     deviate from the target extractive-summary length
    # --------------------------------------------------------------
    if self.config["model"]["length_loss"]:
        _, topk_indices = torch.topk(sent_prob, k=self.config["model"]["k"], dim=1)
        topk_indices = torch.squeeze(topk_indices, -1)
        sum_length = torch.gather(sent_len, dim=1, index=topk_indices)
        sum_length = torch.sum(sum_length, dim=1)
        diff = self.config["data"]["ext_sum_len"] - sum_length
        length_loss = torch.mean(diff.float())
        losses.append(length_loss)
    else:
        length_loss = None

    # --------------------------------------------------------------
    # 4 - Keep the summary representation close to the document
    #     representation
    # --------------------------------------------------------------
    if self.config["model"]["doc_sum_sim_loss"]:
        dialog_rep = torch.squeeze(torch.sum(outs_enc, 1))
        summary_rep = torch.squeeze(torch.sum(outs_enc_filter, 1))
        sim_loss = pairwise_loss(dialog_rep, summary_rep)
        losses.append(sim_loss)
    else:
        sim_loss = None
    # --------------------------------------------------------------
    # 5 - n-th utterance classification task
    # --------------------------------------------------------------
    if self.config["model"]["nsent_classification"]:
        criterion = torch.nn.BCEWithLogitsLoss()
        dia_pre_loss = criterion(dialogue_pre, out_sim.float())
        sum_pre_loss = criterion(summary_pre, out_sim.float())
        pre_kl_loss = kl_categorical(dialogue_pre, summary_pre)
        losses.append(dia_pre_loss)
        losses.append(sum_pre_loss)
        losses.append(pre_kl_loss)
    else:
        dia_pre_loss = None
        sum_pre_loss = None
        pre_kl_loss = None

    # Placeholders for losses this trainer does not compute; they exist only
    # so the gradient-norm debugging helper below always receives defined
    # arguments. (They must not shadow kl_loss from step 2.)
    prior_loss = None
    topic_loss = None

    # --------------------------------------------------------------
    # Plot norms of the loss gradients w.r.t. the compressor
    # --------------------------------------------------------------
    if self.config["plot_norms"] and self.step % self.config["log_interval"] == 0:
        batch_outputs["grad_norm"] = self._debug_grad_norms(
            mean_rec_loss, prior_loss, topic_loss, kl_loss)

    return losses, batch_outputs
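
# The paste does not include kl_categorical or pairwise_loss. Below is a
# minimal sketch of what they might look like, inferred only from how they
# are called above; the definitions in the original repo may differ.

def kl_categorical(p_logits, q_logits):
    # KL(p || q) between categorical distributions given as raw logits,
    # summed over the class dimension and averaged over the batch.
    log_p = F.log_softmax(p_logits, dim=-1)
    log_q = F.log_softmax(q_logits, dim=-1)
    p = log_p.exp()
    return (p * (log_p - log_q)).sum(dim=-1).mean()

def pairwise_loss(x, y):
    # One plausible choice for pulling the paired dialogue and summary
    # representations together: 1 - cosine similarity, averaged over the batch.
    return (1.0 - F.cosine_similarity(x, y, dim=-1)).mean()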
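
# Toy, self-contained demonstration (made-up numbers, not from the paste) of
# the top-k / gather mechanics used in the length loss (step 3): pick the k
# highest-scoring sentences per dialogue and sum their lengths.
sent_prob_demo = torch.tensor([[0.1, 0.7, 0.2],
                               [0.5, 0.3, 0.2]])       # selection scores
sent_len_demo = torch.tensor([[5, 8, 3],
                              [6, 4, 7]])              # sentence lengths
_, topk_demo = torch.topk(sent_prob_demo, k=2, dim=1)  # indices of top-2 sentences
picked = torch.gather(sent_len_demo, dim=1, index=topk_demo)
print(picked.sum(dim=1))                               # tensor([11, 10]) -> 8+3, 6+4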
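# Minimal runnable illustration of how a downstream training step presumably
# consumes the returned loss list: sum (or weight) the scalar losses and
# backpropagate once. The parameters and weights here are hypothetical.
params = [torch.nn.Parameter(torch.randn(3))]
optimizer = torch.optim.SGD(params, lr=0.1)
losses_demo = [params[0].pow(2).sum(),       # stand-in for mean_rec_loss
               0.5 * params[0].abs().sum()]  # stand-in for an auxiliary loss
total = sum(losses_demo)                     # or a weighted sum over the list
optimizer.zero_grad()
total.backward()
optimizer.step()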