beam_search_v0
unknown
python
3 years ago
1.4 kB
7
Indexable
# function to generate output sequence using beam search algorithm
def beam_search_decode(model, opt, src, src_mask, max_len, start_symbol, beam_size):
    """Decode a single source sequence with beam search.

    At every step each unfinished hypothesis is extended with its
    ``beam_size`` best next tokens, and the ``beam_size`` best hypotheses
    overall (by cumulative log-probability) are kept. A hypothesis that has
    produced ``EOS_IDX`` is frozen: it stays in the beam with its score but
    is never extended again. The search stops after ``max_len - 1`` steps or
    as soon as every hypothesis in the beam is finished.

    Args:
        model: object providing ``encode``, ``decode`` and ``out``.
        opt: options object forwarded unchanged to ``nopeak_mask``.
        src: source tensor; moved to ``DEVICE`` before encoding.
        src_mask: source mask; moved to ``DEVICE`` before encoding.
        max_len: maximum length of the output, start symbol included.
        start_symbol: token id that starts every hypothesis.
        beam_size: number of hypotheses kept per step; must be >= 1.

    Returns:
        A ``torch.long`` tensor of shape ``(length, 1)`` with the token ids
        of the best-scoring hypothesis, beginning with ``start_symbol``.

    Raises:
        ValueError: if ``beam_size`` is smaller than 1.
    """
    if beam_size < 1:
        raise ValueError("beam_size must be at least 1")

    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)

    # Decoding only needs forward passes; skip autograd bookkeeping.
    with torch.no_grad():
        # The encoder output is identical for every hypothesis, so it is
        # computed once and shared instead of being stored per beam.
        memory = model.encode(src, src_mask)
        start = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)

        # Each beam is (tokens of shape (len, 1), cumulative score, finished).
        beams = [(start, 0.0, False)]

        for _ in range(max_len - 1):
            if all(finished for _, _, finished in beams):
                break

            candidates = []
            for ys, score, finished in beams:
                if finished:
                    # Keep completed hypotheses competing on their final
                    # score, but do not decode past the end-of-sequence token.
                    candidates.append((ys, score, True))
                    continue

                tgt_mask = nopeak_mask(ys.size(0), opt).to(DEVICE)
                out = model.decode(ys.transpose(0, 1), memory, src_mask, tgt_mask)
                out = out.view(-1, out.size(-1))

                # Normalise the last position so that summed scores are
                # comparable across hypotheses. log_softmax leaves input that
                # is already log-probabilities unchanged.
                # NOTE(review): assumes model.out returns logits or
                # log-probabilities, not plain probabilities - confirm.
                log_probs = torch.log_softmax(model.out(out)[-1], dim=-1)

                k = min(beam_size, log_probs.size(-1))
                top_scores, top_tokens = torch.topk(log_probs, k)
                for token_score, token in zip(top_scores.tolist(), top_tokens.tolist()):
                    extended = torch.cat([ys, ys.new_full((1, 1), token)], dim=0)
                    candidates.append(
                        (extended, score + token_score, token == EOS_IDX)
                    )

            beams = sorted(candidates, key=lambda beam: beam[1], reverse=True)[:beam_size]

    return beams[0][0]