束搜索

http://zh-v2.d2l.ai/chapter_recurrent-modern/beam-search.html

结果,我们得到六个候选输出序列:(1)A;(2)C;(3)A,B;(4)C,E;(5)A,B,D ;(6)C,E,D。

把 A,C, AB,CE 这些最终序列的 prefix 也作为候选项,似乎会导致结果很强烈地偏向这些 prefix,即便 alpha 设置很大。

例如假设第一步选择 A 的条件概率是 0.9,第二步 P(B|A) 不那么有把握,假设是 0.5,P(D|AB) 也是 0.5, 那么 序列 A 的总体概率就是 0.9,序列 ABD 的总体概率是 0.225, 这两个悬殊很大 ln(0.225)/ln(0.9) 大概是 14, 而长度比是 3:1, 因此至少 alpha=2.4以上才能不选择 A , 0.75 的话,选出来基本就是第一个词构成的序列。

另外这使得 k=1 的时候 beam search不会退化成之前的贪心搜索,因为此时至少有可能会返回序列 A,而贪心搜索返回的一定是 ABD (假设 P(A) 比 P(C) 更大,B,D都不是 <eos>)

这 0.9, 0.5 的例子还算极端,现实词典很大的话, 最大条件概率可能也会比较小,似乎会导致更大的悬殊。

因此个人理解候选序列中应该要么是以 <eos> 结尾的序列,要么是 length = max_step 的序列, 除非能设计出更玄的启发式惩罚或者说奖励函数。。

Here is my bin search demo.
code:

def predict_seq2seq_bin_search(net, src_sentence, src_vocab, tgt_vocab, num_steps, bin_size, device):
    net.eval()
    src_tokens = src_vocab[src_sentence.lower().split(' ')] + [src_vocab['<eos>']]
    enc_valid_len = torch.tensor([len(src_tokens)], device=device)
    src_tokens = d2l.truncate_pad(src_tokens, num_steps, src_vocab['<pad>'])
    
    enc_X = torch.tensor(src_tokens, dtype=torch.long, device=device).unsqueeze(0)
    enc_outputs = net.encoder(enc_X, enc_valid_len)
    dec_state = net.decoder.init_state(enc_outputs, enc_valid_len)
    
    dec_X = torch.tensor([tgt_vocab['<bos>']], dtype=torch.long, device=device).unsqueeze(0)
    output_seq, attention_weight_seq = [], []
    import queue
    q = queue.deque()
    
    # 初始化time step = 1的序列
    Y_hat, dec_state = net.decoder(dec_X, dec_state)
    p_Y = Y_hat.softmax(2)
    values, indices = p_Y.topk(k=bin_size)
    for i in range(bin_size):
        # 保存time step = 1时的k个状态
        cum_log_p = values[..., i].log()
        seq_dec_X = [indices[..., i]]
        q.append((cum_log_p, seq_dec_X, dec_state))
    
    # 所有被搜出的序列
    candidate_seqs = []
    for t in range(2, num_steps + 1):
        q_tilda = queue.deque()
        while q:
            cum_log_p, seq_dec_X, dec_state = q.popleft()
            dec_X = seq_dec_X[-1]
            # 直接加入候选序列
            candidate_seqs.append((cum_log_p, seq_dec_X))
            
            Y_hat, dec_state = net.decoder(dec_X, dec_state)
            p_Y = Y_hat.softmax(2)
            values, indices = p_Y.topk(k=bin_size)
            # 估计前一个预测词元,预测当前前K大概率的预测词元,加入临时队列
            for i in range(bin_size):
                # 如果预测序列结束,不再延长
                last_pred = indices[..., i].type(torch.int32).item()
                if last_pred == tgt_vocab['<eos>']:
                    continue
                q_tilda.append((cum_log_p + values[..., i].log(), seq_dec_X + [indices[..., i]], dec_state))
        # 临时队列没有预测序列,则结束
        if not q_tilda:
            break
        # 将临时队列中前K大概率的预测序列加入队列q
        p_tilda = torch.stack([p for p, _, _ in q_tilda], dim=-1)
        values, indices = p_tilda.topk(k=bin_size)
        for i in range(bin_size):
            q.append(q_tilda[indices[..., i].item()])
    
    results = []
    for log_p, seq in candidate_seqs:
        line = [idx.item() for idx in seq]
        log_p = log_p.item()
        score = 1 / (len(line) ** 0.75) * log_p
        results.append((score, line))
    results.sort()
    results = results[::-1]
    output_seq = results[0][1]
    print('candidate_seqs: ')
    for result in results:
        print(result[0], ' '.join(tgt_vocab.to_tokens(result[1])))
        
    return ' '.join(tgt_vocab.to_tokens(output_seq)), attention_weight_seq

def bleu(pred_seq, label_seq, k):
    pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    score = math.exp(min(0, 1 - len_label / len_label))
    for n in range(1, k + 1):
        num_matches, label_subs = 0, collections.defaultdict(int)
        for i in range(len_label - n + 1):
            label_subs[' '.join(label_tokens[i: i + n])] += 1
        for i in range(len_pred - n + 1):
            if label_subs[' '.join(pred_tokens[i: i + n])] > 0:
                num_matches += 1
                label_subs[' '.join(pred_tokens[i: i + n])] -= 1
        score *= (num_matches / (len_pred - n + 1)) ** (0.5 ** n)    
    return score
    
engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    translation, attention_weight_seq = predict_seq2seq_bin_search(
        net, eng, src_vocab, tgt_vocab, num_steps, 2, device)
    
    print(translation, bleu(translation, fra, k=2))

output:

candidate_seqs: 
-0.016121856130327874 va !
-0.024113455787301064 va
-3.013015800556802 <unk> !
-4.11942285821114 va ! faire !
-4.232856580914601 va ! faire foutre !
-5.0038933560003835 va ! faire foutre
-5.041252840753798 va ! faire
-5.058718681335449 <unk>
-6.79943090721856 va ! faire foutre ! question !
-6.893294114463412 va ! faire foutre .
-7.6312102550458345 va ! faire foutre ! question
-7.720330952899213 <unk> ! foutre
-7.988904374736352 va ! faire foutre ! question .
-8.60868592393411 va ! faire foutre ! question ! la !
-9.142400449283082 va ! faire foutre ! question ! la .
-9.364281554470319 va ! faire foutre ! question ! la
-9.814213404925354 va ! faire foutre . foutre
-10.437807511908906 va ! faire foutre ! question . foutre
va ! 1.0
candidate_seqs: 
-0.0038098586909827547 j'ai perdu .
-0.004831752904161648 j'ai perdu
-0.007010838482528925 j'ai
-2.9811325459624705 je <unk> .
-3.514913302269796 j'ai perdu . la .
-3.5598168006984148 j'ai perdu . la <unk> .
-3.893556157542254 je <unk>
-4.069746896682259 j'ai perdu . la
-4.081440138478644 j'ai perdu . la <unk>
-5.369591335646937 j'ai perdu . la . la <unk> .
-5.860156536102295 je
-5.876807906301628 j'ai perdu . la . la .
-5.935191113321948 j'ai perdu . la . la <unk>
-6.370897772736185 j'ai perdu . la . la
-6.553729744910139 je <unk> . la
-7.360312207266691 j'ai perdu . la . la <unk> . la
-7.600138172371028 j'ai perdu . la . la <unk> !
-9.224490385094755 j'ai perdu . la . la <unk> ! la
j'ai perdu . 1.0
candidate_seqs: 
-0.0003027543974668561 il est
-0.0004033067380078137 il
-0.11327251082201431 il est mouillé .
-0.1402585905965848 il est mouillé
-0.6997945558509984 il est paresseux .
-0.8680777901962664 il est paresseux
-4.115854197915852 il est mouillé . .
-4.493449958406501 il est paresseux . .
-5.780498261250492 demande est
-6.968914267928182 il est mouillé . . .
-7.453912034770484 il est paresseux . . .
-9.355443446917416 il est mouillé . . . .
-9.50432300567627 demande
-9.82209869345339 il est paresseux . . . .
-11.266801514180447 il est mouillé . . . . .
-11.763471837877688 il est paresseux . . . . .
-12.902149315494244 il est mouillé . . . . . .
-13.409348214542925 il est paresseux . . . . . .
il est 1.0
candidate_seqs: 
-0.0011929599168567543 je suis
-0.0019314452074468136 je
-0.005115384636352064 je suis chez
-0.04757358189474759 je suis chez moi
-0.05575660180762165 je suis chez moi .
-0.9392764348611705 je suis chez occupé .
-1.1098171370161383 je suis chez occupé
-2.7762335653889516 je suis occupé
-3.1791521902178976 je suis chez moi . partie .
-3.568505363270929 je suis chez moi . partie
-4.101979952741184 je suis chez occupé . la .
-4.428988738824356 je suis chez occupé . la
-4.82773908057205 rentre chez
-5.503399801449647 je suis chez moi . partie . la partie
-5.547864755099986 je suis chez moi . partie . la .
-5.878285967406078 je suis chez moi . partie . la
-6.649581094077673 je suis chez occupé . la . la
-7.40877103805542 rentre
je suis 1.0