About the implementation of beam search: why does it seem to perform worse than greedy search?
Other code dependencies can be imported from RNN · GitHub (TullyMonster) for testing.
from dataclasses import dataclass
from typing import Optional, Callable, Self
import torch
from torch import nn, Tensor
from RNN.encoder_decoder import EncoderDecoder
from RNN.text_preprocessing import Vocabulary, ST
@dataclass
class Beam:
"""
Beam search candidate sequence
:param tokens: List of token indices for the candidate sequence
:param score: Cumulative log probability score
    :param finished: Whether the sequence is complete (the EOS token has been generated, so no further extension is needed)
:param attn_weights: Optional list of attention weights
"""
tokens: list[int]
score: float
finished: bool = False
attn_weights: Optional[list[Tensor]] = None
def extend(self, token_idx: int, token_score: float, eos_idx: int, attn_weight: Optional[Tensor] = None) -> Self:
"""
Create a new candidate sequence by extending the current sequence
:param token_idx: Token index to add
:param token_score: Score of the token
:param eos_idx: Index of the end-of-sequence token, used to determine if the sequence is complete
:param attn_weight: Optional attention weight
:return: New candidate sequence
"""
is_eos: bool = token_idx == eos_idx
new_beam = Beam(tokens=self.tokens.copy() if is_eos else self.tokens + [token_idx],
score=self.score + token_score,
finished=is_eos)
if attn_weight is not None:
new_beam.attn_weights = self.attn_weights.copy() if self.attn_weights else []
new_beam.attn_weights.append(attn_weight)
return new_beam
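# A minimal sketch of Beam.extend semantics (hypothetical token indices; EOS index = 3):
#     b = Beam(tokens=[5, 7], score=-1.2)
#     b.extend(token_idx=9, token_score=-0.5, eos_idx=3)
#         -> Beam(tokens=[5, 7, 9], score=-1.7, finished=False)
#     b.extend(token_idx=3, token_score=-0.1, eos_idx=3)
#         -> Beam(tokens=[5, 7], score=-1.3, finished=True)  # EOS is scored but not appended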
def forecast_beam_search(
module: EncoderDecoder,
src_sentence: str,
src_vocab: Vocabulary,
tgt_vocab: Vocabulary,
device: torch.device,
max_length: int = 20,
beam_size: int = 3,
alpha: float = 0.7,
record_attn_weights: bool = False,
) -> tuple[str, Tensor | None]:
"""
Implement sequence prediction (generation) using beam search
:param module: Sequence-to-sequence model
:param src_sentence: Source language sentence
:param src_vocab: Source language vocabulary
:param tgt_vocab: Target language vocabulary
:param device: Computation device
:param max_length: Maximum length of the generated sequence
:param beam_size: Beam width
:param alpha: Length penalty factor
:param record_attn_weights: Whether to save attention weights
    :return: Generated target language sentence, and the attention weights if requested
"""
length_penalty: Callable[[int], float] = lambda length: ((5 + length) ** alpha) / (
(5 + 1) ** alpha) # Length penalty
score_normalize: Callable[[Beam], float] = lambda beam: beam.score / length_penalty(length) \
if (length := len(beam.tokens)) > 0 else beam.score # Score normalization
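    # Worked example of this (GNMT-style) length penalty with the default alpha = 0.7:
    #     lp(1)  = (6/6)**0.7  = 1.00
    #     lp(5)  = (10/6)**0.7 ≈ 1.43
    #     lp(10) = (15/6)**0.7 ≈ 1.90
    # Cumulative log probabilities are negative, so dividing by a larger penalty raises the
    # normalized score of longer sequences, offsetting beam search's bias toward short outputs.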
# Get special token indices
pad_src_index: int = src_vocab.get_index(ST.PAD)
sos_tgt_index: int = tgt_vocab.get_index(ST.SOS)
eos_tgt_index: int = tgt_vocab.get_index(ST.EOS)
# Input preprocessing
src_tokens: list[int] = src_vocab.encode(
[*src_sentence.lower().split(), ST.EOS]) # Case conversion, tokenization, add EOS token
src_tokens_pad_trunc: list[int] = [ # Truncate or pad to specified length
*src_tokens[:max_length],
*[pad_src_index] * (max_length - len(src_tokens))
]
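    # E.g. (hypothetical indices) "I like apples ." -> ["i", "like", "apples", ".", "<eos>"]
    # encodes to 5 token indices followed by 15 PAD indices for max_length = 20; a source
    # sentence longer than max_length would instead be truncated.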
# Initialize beam set (candidate sequences). Use empty Beam instance to represent the starting point of search
beams: list[Beam] = [Beam(tokens=[], score=0.0, finished=False, attn_weights=[] if record_attn_weights else None)]
module.eval()
with torch.no_grad():
src_input = torch.tensor([src_tokens_pad_trunc], dtype=torch.long, device=device) # (BATCH_SIZE=1, SEQ_LENGTH)
src_valid_length = torch.tensor([len(src_tokens)], device=device) # (BATCH_SIZE=1,)
enc_output = module.encoder(src_input, src_valid_length) # Encoder output
# Decoding
for step in range(max_length):
            if all(beam.finished for beam in beams):
                break  # End decoding once every candidate sequence is marked as "finished"
# Process finished and unfinished candidate sequences separately
beams_done = [beam for beam in beams if beam.finished]
beams_live = [beam for beam in beams if not beam.finished]
candidates = beams_done.copy() # Add finished sequences directly to the candidate set
for beam in beams_live: # Continue decoding only for unfinished sequences
                # Rebuild the decoder state from scratch each step by replaying the beam's prefix
                # (quadratic in sequence length, but avoids storing a decoder state per beam)
                dec_state = module.decoder.init_state(enc_output)
                dec_input = torch.tensor([[sos_tgt_index] + beam.tokens], dtype=torch.long, device=device)
                if step > 0:
                    _, dec_state = module.decoder(dec_input[:, :-1], dec_state)  # Replay the prefix
                output, _ = module.decoder(dec_input[:, -1:], dec_state)  # Prediction for the current time step
                # Assuming decoder output of shape (BATCH_SIZE=1, STEPS=1, VOCAB_SIZE): take the last
                # step's logits, keeping the batch dimension so the [0, i] indexing below is valid
                probs: Tensor = nn.functional.log_softmax(output[:, -1], dim=-1)  # (1, VOCAB_SIZE) log probabilities
topk_probs, topk_indices = probs.topk(beam_size) # Select top-k tokens with highest probabilities
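                # topk returns the k largest log probabilities and their vocabulary indices, e.g.
                # (hypothetical values) topk_probs = [[-0.1, -1.9, -2.3]], topk_indices = [[42, 7, 99]]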
                # NOTE: `output` is the logits tensor, so `len(output)` is the batch size (1) and the
                # original guard `len(output) > 1` could never fire (and `output[1]` would index a
                # non-existent batch element). Attention weights are therefore never actually recorded;
                # capturing them would require the decoder to expose them (e.g. through `dec_state`).
                attn_weight = None
                # Create new candidate sequences
                for i in range(beam_size):
                    new_beam = beam.extend(token_idx=topk_indices[0, i].item(),
                                           token_score=topk_probs[0, i].item(),
                                           eos_idx=eos_tgt_index,
                                           attn_weight=attn_weight)
                    candidates.append(new_beam)
            beams = sorted(candidates, key=score_normalize, reverse=True)[:beam_size]  # Keep the beam_size candidates with the highest normalized scores
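            # E.g. with beam_size = 3: up to 3 live beams x 3 extensions each, plus any finished
            # beams, compete here, and only the 3 best normalized scores survive. Finished beams
            # are re-ranked against live ones at every step using the length-normalized score.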
best_beam = max(beams, key=score_normalize) # Use adjusted score to select the best sequence
tgt_sentence = ' '.join(tgt_vocab.decode(best_beam.tokens))
stack_attn_weights = torch.stack(best_beam.attn_weights) if record_attn_weights and best_beam.attn_weights else None
return tgt_sentence, stack_attn_weights
if __name__ == '__main__':
from torch import optim
from RNN.seq2seq import TestSentence, SequenceLengthCrossEntropyLoss, Seq2SeqEncoder, Seq2SeqDecoder, \
train_one_epoch, forecast_greedy_search, evaluate_bleu
from translation_dataset_loader import nmt_eng_fra_dataloader
BATCH_SIZE = 128
SEQ_LENGTH = 20
EMBED_DIM = 256
HIDDEN_NUM = 256
NUM_LAYERS = 2
DROPOUT = 0.2
LEARNING_RATE = 0.0005
EPOCHS_NUM = 50
TEST_INTERVAL = 1
TEST_SENTENCES = TestSentence(src=["I like apples .",
"She reads books regularly .",
"They play soccer together .",
"We studied French yesterday .",
"The weather is beautiful today ."],
tgt=[["J'aime les pommes .", "J'adore les pommes .", "Les pommes me plaisent .",
"Je raffole des pommes .", "J'apprécie les pommes ."],
["Elle lit des livres régulièrement .", "Elle lit des livres souvent .",
"Elle lit des livres fréquemment .", "Elle lit régulièrement des ouvrages ."],
["Ils jouent au football ensemble .", "Ils jouent au foot ensemble .",
"Ils pratiquent le football ensemble .", "Ensemble, ils jouent au football ."],
["Nous avons étudié le français hier .",
"Hier, nous avons étudié le français .",
"Nous avons appris le français hier .", "Nous avons fait du français hier ."],
["Le temps est magnifique aujourd'hui .", "Il fait beau aujourd'hui .",
"Le temps est splendide aujourd'hui .", "La météo est belle aujourd'hui ."]])
data_iter, eng_vocab, fra_vocab = nmt_eng_fra_dataloader(BATCH_SIZE, SEQ_LENGTH, num_workers=8)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
nmt_model = EncoderDecoder(encoder=Seq2SeqEncoder(len(eng_vocab), EMBED_DIM, HIDDEN_NUM, NUM_LAYERS, DROPOUT),
decoder=Seq2SeqDecoder(len(fra_vocab), EMBED_DIM, HIDDEN_NUM, NUM_LAYERS, DROPOUT),
                               device=device)  # Use the default parameter initialization; no manual initialization
optimizer = optim.Adam(nmt_model.parameters(), lr=LEARNING_RATE)
criterion = SequenceLengthCrossEntropyLoss()
for epoch in range(EPOCHS_NUM):
loss, speed = train_one_epoch(nmt_model, data_iter, optimizer, criterion, fra_vocab, device)
print(f'Epoch {epoch + 1:03}: loss {loss:.3f}, speed {speed:.1f} tokens/sec')
if (epoch + 1) % TEST_INTERVAL == 0:
for eng, fra in zip(TEST_SENTENCES.src, TEST_SENTENCES.tgt):
forecast_fra_greedy, _ = forecast_greedy_search(nmt_model, eng, eng_vocab, fra_vocab, device)
forecast_fra_beam, _ = forecast_beam_search(nmt_model, eng, eng_vocab, fra_vocab, device, alpha=0.3)
print(f'INFO (greedy search): '
f'{eng.ljust(max(map(len, TEST_SENTENCES.src)))} '
f'→ (BLEU={evaluate_bleu(forecast_fra_greedy, fra, max_n_gram=3):.2f}) {forecast_fra_greedy}')
print(f'INFO (beam search ): '
f'{eng.ljust(max(map(len, TEST_SENTENCES.src)))} '
f'→ (BLEU={evaluate_bleu(forecast_fra_beam, fra, max_n_gram=3):.2f}) {forecast_fra_beam}')
Output:
Epoch 031: loss 0.024, speed 47454.5 tokens/sec
INFO (greedy search): I like apples . → (BLEU=0.63) j'aime les pommes .
INFO (beam search ): I like apples . → (BLEU=0.00) j'aime les pommes pommes .
INFO (greedy search): She reads books regularly . → (BLEU=0.00) elle lit avec deux balles .
INFO (beam search ): She reads books regularly . → (BLEU=0.00) elle lit <UNK> avec de l'argent .
INFO (greedy search): They play soccer together . → (BLEU=1.00) ils jouent au foot ensemble .
INFO (beam search ): They play soccer together . → (BLEU=0.41) ils jouent jouent au foot restaurant .
INFO (greedy search): We studied French yesterday . → (BLEU=1.00) nous avons étudié le français hier .
INFO (beam search ): We studied French yesterday . → (BLEU=1.00) nous avons étudié le français hier .
INFO (greedy search): The weather is beautiful today . → (BLEU=0.00) aujourd'hui , il fait beau .
INFO (beam search ): The weather is beautiful today . → (BLEU=0.58) aujourd'hui il fait beau aujourd'hui .
Output truncated; training not yet complete.
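A quick sanity check to isolate whether the gap comes from a bug or from search behaviour (a sketch reusing the model and vocabularies from the listing above): with beam_size=1, beam search should, up to ties, reproduce greedy decoding.

sentence = "I like apples ."
greedy_out, _ = forecast_greedy_search(nmt_model, sentence, eng_vocab, fra_vocab, device)
beam1_out, _ = forecast_beam_search(nmt_model, sentence, eng_vocab, fra_vocab, device, beam_size=1)
print(greedy_out, beam1_out, sep='\n')  # If these differ, the discrepancy is a bug, not a search-quality issue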