import torch
from torch.nn import functional as F

# `kl_categorical` and `pairwise_loss` are project-level helpers; minimal
# illustrative sketches of both (assumptions, not the project's actual code)
# are provided at the end of this file.


def _process_batch(self, inp_x, inp_sim, out_sim, inp_y, out_y, sim_len, sent_len, sent_num, y_lengths):
    """
    Process one batch of data.

    :param inp_x: input dialogue data
    :param inp_sim: the k utterances most similar to the n-th utterance
    :param out_sim: ground-truth labels for the n-th-utterance classification
    :param inp_y: the input n-th utterance
    :param out_y: the expected (ground-truth) n-th utterance
    :param sim_len: actual length of each similar utterance
    :param sent_len: actual length of each sentence in the dialogue
    :param sent_num: actual number of utterances in the dialogue
    :param y_lengths: actual length of the n-th utterance
    :return: (losses, batch_outputs) -- a list of losses and the outputs for
        this batch; batch_outputs is a dict
    """
    self.model.train()
    # Run the forward pass. Note that `sent_len` is re-bound below to the
    # model's output, shadowing the argument of the same name.
    outputs = self.model(self.config["model"]["k"], inp_x, inp_sim, inp_y,
                         sim_len, sent_len, sent_num, y_lengths)
    sent_prob, outs_enc, outs_enc_filter, dec1, dec2, sent_len, dialogue_pre, summary_pre = outputs
    batch_outputs = {"model_outputs": outputs}
    # --------------------------------------------------------------
    # 1 - Loss for predicting the n-th utterance, generated from the
    #     full dialogue
    # --------------------------------------------------------------
    _dec1_logits = dec1[0].contiguous().view(-1, dec1[0].size(-1))
    _x_labels = out_y.contiguous().view(-1)
    nsent_loss = F.cross_entropy(_dec1_logits, _x_labels,
                                 ignore_index=0, reduction='none')
    nsent_loss_token = nsent_loss.view(out_y.size())
    batch_outputs["n_sent"] = nsent_loss_token
    mean_rec_loss = nsent_loss.sum() / y_lengths.float().sum()
    losses = [mean_rec_loss]
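    # Note on the normalization above: with ignore_index=0, padding positions
    # contribute zero to `nsent_loss`, so dividing the summed loss by
    # y_lengths.sum() (the number of real tokens) gives the average loss per
    # non-pad token. For example, with two targets of lengths 3 and 5 the sum
    # over 8 real tokens is divided by 8, not by the padded size 2 * 5 = 10.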
    # --------------------------------------------------------------
    # 1.5 - Loss for predicting the n-th utterance from the summary
    # --------------------------------------------------------------
    if self.config["model"]["n_sent_sum_loss"]:
        _dec2_logits = dec2[0].contiguous().view(-1, dec2[0].size(-1))
        nsent_loss_sum = F.cross_entropy(_dec2_logits, _x_labels,
                                         ignore_index=0, reduction='none')
        nsent_loss_token_sum = nsent_loss_sum.view(out_y.size())
        batch_outputs["n_sent_sum"] = nsent_loss_token_sum
        mean_rec_sum_loss = nsent_loss_sum.sum() / y_lengths.float().sum()
        losses.append(mean_rec_sum_loss)
    else:
        mean_rec_sum_loss = None
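    # Design note (inferred from the section headers): dec1 decodes the n-th
    # utterance from the full dialogue, while dec2 appears to decode it from
    # the extracted summary (outs_enc_filter) alone, so a low loss here
    # suggests the summary retains enough information to predict the
    # n-th utterance.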
    # --------------------------------------------------------------
    # 2 - KL divergence, pushing the two decoder distributions together
    # --------------------------------------------------------------
    if self.config["model"]["doc_sum_kl_loss"]:
        # `_dec1_logits` was already built in section 1 above; only the
        # dec2 logits need to be (re)built here.
        _dec2_logits = dec2[0].contiguous().view(-1, dec2[0].size(-1))
        kl_loss = kl_categorical(_dec1_logits, _dec2_logits)
        losses.append(kl_loss)
    else:
        kl_loss = None
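    # `kl_categorical` is assumed to treat each row of logits as a categorical
    # distribution and return the (mean) KL divergence between the two; see
    # the sketch at the end of this file.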
    # --------------------------------------------------------------
    # 3 - Length loss
    # --------------------------------------------------------------
    if self.config["model"]["length_loss"]:
        # Select the k utterances the extractor scores highest, look up
        # their true lengths, and compare the total against the target
        # extractive-summary length.
        _, topk_indices = torch.topk(sent_prob, k=self.config["model"]["k"], dim=1)
        topk_indices = torch.squeeze(topk_indices, -1)
        sum_length = torch.gather(sent_len, dim=1, index=topk_indices)
        sum_length = torch.sum(sum_length, dim=1)
        tmp = torch.sub(self.config["data"]["ext_sum_len"], sum_length)
        length_loss = torch.mean(tmp.float())
        losses.append(length_loss)
    else:
        length_loss = None
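    # For illustration: with sent_len = [[5, 3, 8, 2]] and k = 2, if the
    # extractor's top-2 indices are [[2, 0]], torch.gather picks the lengths
    # [[8, 5]] and sum_length = [13]; the loss is then ext_sum_len - 13
    # averaged over the batch (and can be negative when the selection exceeds
    # the length budget).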
    # --------------------------------------------------------------
    # 4 - Keep the summary representation close to the representation
    #     of the full dialogue
    # --------------------------------------------------------------
    if self.config["model"]["doc_sum_sim_loss"]:
        dialog_rep = torch.squeeze(torch.sum(outs_enc, 1))
        summary_rep = torch.squeeze(torch.sum(outs_enc_filter, 1))
        sim_loss = pairwise_loss(dialog_rep, summary_rep)
        losses.append(sim_loss)
    else:
        sim_loss = None
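    # `pairwise_loss` is assumed to measure the distance between the two
    # (summed-over-time) representations, e.g. an L2 or cosine penalty; see
    # the sketch at the end of this file.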
    # --------------------------------------------------------------
    # 5 - n-th utterance classification task
    # --------------------------------------------------------------
    if self.config["model"]["nsent_classification"]:
        criterion = torch.nn.BCEWithLogitsLoss()
        dia_pre_loss = criterion(dialogue_pre, out_sim.float())
        sum_pre_loss = criterion(summary_pre, out_sim.float())
        # Encourage the dialogue-based and summary-based predictions to agree.
        pre_kl_loss = kl_categorical(dialogue_pre, summary_pre)
        losses.append(dia_pre_loss)
        losses.append(sum_pre_loss)
        losses.append(pre_kl_loss)
    else:
        dia_pre_loss = None
        sum_pre_loss = None
        pre_kl_loss = None
    # `prior_loss` and `topic_loss` are never computed in this method. They
    # were previously defined only inside the `else` branch above (which also
    # clobbered `kl_loss`); they are set unconditionally here so the
    # gradient-norm debugging below cannot raise a NameError.
    prior_loss = None
    topic_loss = None
    # --------------------------------------------------------------
    # Plot norms of the loss gradients w.r.t. the compressor
    # --------------------------------------------------------------
    if self.config["plot_norms"] and self.step % self.config["log_interval"] == 0:
        batch_outputs["grad_norm"] = self._debug_grad_norms(
            mean_rec_loss,
            prior_loss,
            topic_loss,
            kl_loss)
    return losses, batch_outputs
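
# --------------------------------------------------------------
# Illustrative sketches of the two assumed project helpers used above.
# These are NOT the project's actual implementations -- only minimal,
# plausible versions so that the loss terms above are self-contained.
# --------------------------------------------------------------
def kl_categorical(logits_p, logits_q):
    """Sketch: mean KL(p || q) between row-wise categorical distributions.

    Assumes both inputs are unnormalized logits of shape (N, num_classes).
    """
    p = F.softmax(logits_p, dim=-1)
    log_p = F.log_softmax(logits_p, dim=-1)
    log_q = F.log_softmax(logits_q, dim=-1)
    return (p * (log_p - log_q)).sum(dim=-1).mean()


def pairwise_loss(rep_a, rep_b):
    """Sketch: pull two batches of representations together.

    Assumes inputs of shape (batch, hidden); uses the mean squared L2
    distance, though the original helper may use cosine distance or
    another metric.
    """
    return torch.mean(torch.sum((rep_a - rep_b) ** 2, dim=-1))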