beam_search_v0
unknown
python
2 years ago
1.4 kB
4
Indexable
def beam_search_decode(model, opt, src, src_mask, max_len, start_symbol, beam_size):
    """Decode one source sequence with beam search and return the best hypothesis.

    Args:
        model: Encoder-decoder exposing ``encode(src, src_mask)``,
            ``decode(tgt, memory, src_mask, tgt_mask)`` and ``out(hidden)``.
            NOTE(review): ``out`` is assumed to return unnormalised logits over
            the target vocabulary -- confirm. A log-softmax is applied below;
            it is a no-op if ``out`` already returns log-probabilities.
        opt: Options object forwarded unchanged to ``nopeak_mask``.
        src: Source token tensor; moved to ``DEVICE``.
        src_mask: Source mask tensor; moved to ``DEVICE``.
        max_len: Maximum output length, counting the start symbol.
        start_symbol: Token id every hypothesis begins with.
        beam_size: Number of hypotheses kept after each decoding step.

    Returns:
        Long tensor of shape ``(length, 1)`` holding the token ids of the
        highest-scoring hypothesis. It starts with ``start_symbol`` and ends
        with ``EOS_IDX`` if the winning hypothesis generated one.

    Raises:
        ValueError: If ``beam_size`` is smaller than 1.
    """
    if beam_size < 1:
        raise ValueError(f"beam_size must be >= 1, got {beam_size}")

    # Inference only: avoid building an autograd graph across decoding steps.
    with torch.no_grad():
        src = src.to(DEVICE)
        src_mask = src_mask.to(DEVICE)
        memory = model.encode(src, src_mask)
        ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)

        # Each beam: (token ids of shape (length, 1), summed log-prob, finished?)
        beams = [(ys, 0.0, False)]
        for _ in range(max_len - 1):
            if all(finished for _, _, finished in beams):
                break  # every hypothesis has emitted EOS; nothing left to expand

            candidates = []
            for tokens, score, finished in beams:
                if finished:
                    # A completed hypothesis is carried over unchanged so it still
                    # competes with the extensions of unfinished ones.
                    candidates.append((tokens, score, True))
                    continue

                tgt_mask = nopeak_mask(tokens.size(0), opt).to(DEVICE)
                out = model.decode(tokens.transpose(0, 1), memory, src_mask, tgt_mask)
                out = out.view(-1, out.size(-1))
                # Normalise the last position so per-step scores are additive
                # log-probabilities that are comparable across hypotheses.
                log_probs = model.out(out)[-1].log_softmax(dim=-1)

                k = min(beam_size, log_probs.size(-1))
                top_scores, top_ids = log_probs.topk(k)
                for token_score, token_id in zip(top_scores.tolist(), top_ids.tolist()):
                    next_token = torch.full(
                        (1, 1), token_id, dtype=tokens.dtype, device=tokens.device
                    )
                    candidates.append(
                        (
                            torch.cat([tokens, next_token], dim=0),
                            score + token_score,
                            token_id == EOS_IDX,
                        )
                    )

            # NOTE: no length normalisation, so shorter finished hypotheses are
            # favoured (same as the original ranking by summed score).
            beams = sorted(candidates, key=lambda beam: beam[1], reverse=True)[:beam_size]

    return beams[0][0]
Editor is loading...