# beam_search_v0
#
# mail@pastecode.io avatar
# unknown
# python
# 2 years ago
# 1.4 kB
# 2
# Indexable
# Never
def beam_search_decode(model, opt, src, src_mask, max_len, start_symbol, beam_size):
    """Generate an output sequence for ``src`` using beam search.

    At each step every unfinished hypothesis is extended with its
    ``beam_size`` most likely next tokens. The ``beam_size`` highest-scoring
    hypotheses overall are kept. A hypothesis that emits ``EOS_IDX`` is
    marked finished: it is carried forward unchanged (never extended again)
    and keeps competing on score. Decoding stops after ``max_len - 1`` steps,
    or earlier once every kept hypothesis is finished.

    Scores are cumulative log-probabilities with no length normalisation,
    so shorter finished hypotheses are naturally favoured.

    Args:
        model: Encoder-decoder exposing ``encode(src, src_mask)``,
            ``decode(tgt, memory, src_mask, tgt_mask)`` and ``out(hidden)``.
        opt: Options object forwarded unchanged to ``nopeak_mask``.
        src: Source token tensor, moved to ``DEVICE``. Presumably a single
            example of shape (1, src_len) -- TODO confirm against callers.
        src_mask: Source mask passed to ``encode``/``decode``; moved to
            ``DEVICE``.
        max_len: Maximum output length, counting the start symbol.
        start_symbol: Token id the output sequence starts with.
        beam_size: Number of hypotheses kept per step (must be >= 1).

    Returns:
        Long tensor of shape (seq_len, 1) holding the best hypothesis. It
        starts with ``start_symbol`` and ends with ``EOS_IDX`` if one was
        produced.

    Raises:
        ValueError: If ``beam_size`` is smaller than 1.
    """
    if beam_size < 1:
        raise ValueError("beam_size must be >= 1")

    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)

    # Pure inference: no autograd graphs need to be built or kept alive.
    with torch.no_grad():
        memory = model.encode(src, src_mask)
        start = torch.full((1, 1), start_symbol, dtype=torch.long, device=DEVICE)

        # Each hypothesis is (tokens of shape (len, 1), cumulative log-prob, finished?).
        beams = [(start, 0.0, False)]
        for _ in range(max_len - 1):
            candidates = []
            for ys, score, finished in beams:
                if finished:
                    # Completed hypotheses are kept as-is, never extended past EOS.
                    candidates.append((ys, score, True))
                    continue

                tgt_mask = nopeak_mask(ys.size(0), opt).to(DEVICE)
                # ys is stored sequence-first (len, 1); decode is fed (1, len).
                out = model.decode(ys.transpose(0, 1), memory, src_mask, tgt_mask)
                out = out.view(-1, out.size(-1))
                # Score only the distribution for the last position.
                # log_softmax turns logits into log-probs; it is a no-op if
                # model.out already returns log-probs (assumed logits -- TODO confirm).
                log_probs = torch.log_softmax(model.out(out)[-1], dim=-1)

                k = min(beam_size, log_probs.size(-1))
                top_scores, top_ids = torch.topk(log_probs, k)
                for word_score, word in zip(top_scores.tolist(), top_ids.tolist()):
                    next_tok = torch.full((1, 1), word, dtype=ys.dtype, device=ys.device)
                    candidates.append(
                        (torch.cat([ys, next_tok], dim=0), score + word_score, word == EOS_IDX)
                    )

            beams = sorted(candidates, key=lambda beam: beam[1], reverse=True)[:beam_size]
            if all(done for _, _, done in beams):
                break

    return beams[0][0]