http://zh-v2.d2l.ai/chapter_recurrent-modern/beam-search.html
结果,我们得到六个候选输出序列:(1)A;(2)C;(3)A,B;(4)C,E;(5)A,B,D ;(6)C,E,D。
把 A,C, AB,CE 这些最终序列的 prefix 也作为候选项,似乎会导致结果很强烈地偏向这些 prefix,即便 alpha 设置很大。
例如假设第一步选择 A 的条件概率是 0.9,第二步 P(B|A) 不那么有把握,假设是 0.5,P(D|AB) 也是 0.5, 那么 序列 A 的总体概率就是 0.9,序列 ABD 的总体概率是 0.225, 这两个悬殊很大 ln(0.225)/ln(0.9) 大概是 14, 而长度比是 3:1, 因此至少 alpha=2.4以上才能不选择 A , 0.75 的话,选出来基本就是第一个词构成的序列。
另外这使得 k=1 的时候 beam search不会退化成之前的贪心搜索,因为此时至少有可能会返回序列 A,而贪心搜索返回的一定是 ABD (假设 P(A) 比 P(C) 更大,B,D都不是 <eos>)
这 0.9, 0.5 的例子还算极端,现实词典很大的话, 最大条件概率可能也会比较小,似乎会导致更大的悬殊。
因此个人理解候选序列中应该要么是以 <eos> 结尾的序列,要么是 length = max_step 的序列, 除非能设计出更玄的启发式惩罚或者说奖励函数。。
Here is my bin search demo.
code:
def predict_seq2seq_bin_search(net, src_sentence, src_vocab, tgt_vocab, num_steps, bin_size, device):
net.eval()
src_tokens = src_vocab[src_sentence.lower().split(' ')] + [src_vocab['<eos>']]
enc_valid_len = torch.tensor([len(src_tokens)], device=device)
src_tokens = d2l.truncate_pad(src_tokens, num_steps, src_vocab['<pad>'])
enc_X = torch.tensor(src_tokens, dtype=torch.long, device=device).unsqueeze(0)
enc_outputs = net.encoder(enc_X, enc_valid_len)
dec_state = net.decoder.init_state(enc_outputs, enc_valid_len)
dec_X = torch.tensor([tgt_vocab['<bos>']], dtype=torch.long, device=device).unsqueeze(0)
output_seq, attention_weight_seq = [], []
import queue
q = queue.deque()
# 初始化time step = 1的序列
Y_hat, dec_state = net.decoder(dec_X, dec_state)
p_Y = Y_hat.softmax(2)
values, indices = p_Y.topk(k=bin_size)
for i in range(bin_size):
# 保存time step = 1时的k个状态
cum_log_p = values[..., i].log()
seq_dec_X = [indices[..., i]]
q.append((cum_log_p, seq_dec_X, dec_state))
# 所有被搜出的序列
candidate_seqs = []
for t in range(2, num_steps + 1):
q_tilda = queue.deque()
while q:
cum_log_p, seq_dec_X, dec_state = q.popleft()
dec_X = seq_dec_X[-1]
# 直接加入候选序列
candidate_seqs.append((cum_log_p, seq_dec_X))
Y_hat, dec_state = net.decoder(dec_X, dec_state)
p_Y = Y_hat.softmax(2)
values, indices = p_Y.topk(k=bin_size)
# 估计前一个预测词元,预测当前前K大概率的预测词元,加入临时队列
for i in range(bin_size):
# 如果预测序列结束,不再延长
last_pred = indices[..., i].type(torch.int32).item()
if last_pred == tgt_vocab['<eos>']:
continue
q_tilda.append((cum_log_p + values[..., i].log(), seq_dec_X + [indices[..., i]], dec_state))
# 临时队列没有预测序列,则结束
if not q_tilda:
break
# 将临时队列中前K大概率的预测序列加入队列q
p_tilda = torch.stack([p for p, _, _ in q_tilda], dim=-1)
values, indices = p_tilda.topk(k=bin_size)
for i in range(bin_size):
q.append(q_tilda[indices[..., i].item()])
results = []
for log_p, seq in candidate_seqs:
line = [idx.item() for idx in seq]
log_p = log_p.item()
score = 1 / (len(line) ** 0.75) * log_p
results.append((score, line))
results.sort()
results = results[::-1]
output_seq = results[0][1]
print('candidate_seqs: ')
for result in results:
print(result[0], ' '.join(tgt_vocab.to_tokens(result[1])))
return ' '.join(tgt_vocab.to_tokens(output_seq)), attention_weight_seq
def bleu(pred_seq, label_seq, k):
pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')
len_pred, len_label = len(pred_tokens), len(label_tokens)
score = math.exp(min(0, 1 - len_label / len_label))
for n in range(1, k + 1):
num_matches, label_subs = 0, collections.defaultdict(int)
for i in range(len_label - n + 1):
label_subs[' '.join(label_tokens[i: i + n])] += 1
for i in range(len_pred - n + 1):
if label_subs[' '.join(pred_tokens[i: i + n])] > 0:
num_matches += 1
label_subs[' '.join(pred_tokens[i: i + n])] -= 1
score *= (num_matches / (len_pred - n + 1)) ** (0.5 ** n)
return score
engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
translation, attention_weight_seq = predict_seq2seq_bin_search(
net, eng, src_vocab, tgt_vocab, num_steps, 2, device)
print(translation, bleu(translation, fra, k=2))
output:
candidate_seqs:
-0.016121856130327874 va !
-0.024113455787301064 va
-3.013015800556802 <unk> !
-4.11942285821114 va ! faire !
-4.232856580914601 va ! faire foutre !
-5.0038933560003835 va ! faire foutre
-5.041252840753798 va ! faire
-5.058718681335449 <unk>
-6.79943090721856 va ! faire foutre ! question !
-6.893294114463412 va ! faire foutre .
-7.6312102550458345 va ! faire foutre ! question
-7.720330952899213 <unk> ! foutre
-7.988904374736352 va ! faire foutre ! question .
-8.60868592393411 va ! faire foutre ! question ! la !
-9.142400449283082 va ! faire foutre ! question ! la .
-9.364281554470319 va ! faire foutre ! question ! la
-9.814213404925354 va ! faire foutre . foutre
-10.437807511908906 va ! faire foutre ! question . foutre
va ! 1.0
candidate_seqs:
-0.0038098586909827547 j'ai perdu .
-0.004831752904161648 j'ai perdu
-0.007010838482528925 j'ai
-2.9811325459624705 je <unk> .
-3.514913302269796 j'ai perdu . la .
-3.5598168006984148 j'ai perdu . la <unk> .
-3.893556157542254 je <unk>
-4.069746896682259 j'ai perdu . la
-4.081440138478644 j'ai perdu . la <unk>
-5.369591335646937 j'ai perdu . la . la <unk> .
-5.860156536102295 je
-5.876807906301628 j'ai perdu . la . la .
-5.935191113321948 j'ai perdu . la . la <unk>
-6.370897772736185 j'ai perdu . la . la
-6.553729744910139 je <unk> . la
-7.360312207266691 j'ai perdu . la . la <unk> . la
-7.600138172371028 j'ai perdu . la . la <unk> !
-9.224490385094755 j'ai perdu . la . la <unk> ! la
j'ai perdu . 1.0
candidate_seqs:
-0.0003027543974668561 il est
-0.0004033067380078137 il
-0.11327251082201431 il est mouillé .
-0.1402585905965848 il est mouillé
-0.6997945558509984 il est paresseux .
-0.8680777901962664 il est paresseux
-4.115854197915852 il est mouillé . .
-4.493449958406501 il est paresseux . .
-5.780498261250492 demande est
-6.968914267928182 il est mouillé . . .
-7.453912034770484 il est paresseux . . .
-9.355443446917416 il est mouillé . . . .
-9.50432300567627 demande
-9.82209869345339 il est paresseux . . . .
-11.266801514180447 il est mouillé . . . . .
-11.763471837877688 il est paresseux . . . . .
-12.902149315494244 il est mouillé . . . . . .
-13.409348214542925 il est paresseux . . . . . .
il est 1.0
candidate_seqs:
-0.0011929599168567543 je suis
-0.0019314452074468136 je
-0.005115384636352064 je suis chez
-0.04757358189474759 je suis chez moi
-0.05575660180762165 je suis chez moi .
-0.9392764348611705 je suis chez occupé .
-1.1098171370161383 je suis chez occupé
-2.7762335653889516 je suis occupé
-3.1791521902178976 je suis chez moi . partie .
-3.568505363270929 je suis chez moi . partie
-4.101979952741184 je suis chez occupé . la .
-4.428988738824356 je suis chez occupé . la
-4.82773908057205 rentre chez
-5.503399801449647 je suis chez moi . partie . la partie
-5.547864755099986 je suis chez moi . partie . la .
-5.878285967406078 je suis chez moi . partie . la
-6.649581094077673 je suis chez occupé . la . la
-7.40877103805542 rentre
je suis 1.0