Beam Search performs worse than Greedy Search?

This concerns my implementation of beam search: why does it seem to perform worse than greedy search?

For the other code dependencies, please import from RNN · GitHub (TullyMonster) for testing.

from dataclasses import dataclass
from typing import Optional, Callable, Self

import torch
from torch import nn, Tensor

from RNN.encoder_decoder import EncoderDecoder
from RNN.text_preprocessing import Vocabulary, ST


@dataclass
class Beam:
    """
    Beam search candidate sequence

    :param tokens: List of token indices for the candidate sequence
    :param score: Cumulative log probability score
    :param finished: Whether the sequence is complete (an EOS token has been encountered) and needs no further extension
    :param attn_weights: Optional list of attention weights
    """
    tokens: list[int]
    score: float
    finished: bool = False
    attn_weights: Optional[list[Tensor]] = None

    def extend(self, token_idx: int, token_score: float, eos_idx: int, attn_weight: Optional[Tensor] = None) -> Self:
        """
        Create a new candidate sequence by extending the current sequence

        :param token_idx: Token index to add
        :param token_score: Score of the token
        :param eos_idx: Index of the end-of-sequence token, used to determine if the sequence is complete
        :param attn_weight: Optional attention weight

        :return: New candidate sequence
        """
        is_eos: bool = token_idx == eos_idx
        new_beam = Beam(tokens=self.tokens.copy() if is_eos else self.tokens + [token_idx],
                        score=self.score + token_score,
                        finished=is_eos)
        if attn_weight is not None:
            new_beam.attn_weights = self.attn_weights.copy() if self.attn_weights else []
            new_beam.attn_weights.append(attn_weight)

        return new_beam
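
# A quick illustration of Beam.extend (hypothetical indices and scores): starting from
# Beam(tokens=[], score=0.0), extend(token_idx=7, token_score=-0.2, eos_idx=2) yields
# Beam(tokens=[7], score=-0.2, finished=False). Extending that beam with the EOS index 2
# marks it finished without storing the EOS token itself, although the EOS log
# probability is still added to the cumulative score.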


def forecast_beam_search(
        module: EncoderDecoder,
        src_sentence: str,
        src_vocab: Vocabulary,
        tgt_vocab: Vocabulary,
        device: torch.device,
        max_length: int = 20,
        beam_size: int = 3,
        alpha: float = 0.7,
        record_attn_weights: bool = False,
) -> tuple[str, Tensor | None]:
    """
    Implement sequence prediction (generation) using beam search

    :param module: Sequence-to-sequence model
    :param src_sentence: Source language sentence
    :param src_vocab: Source language vocabulary
    :param tgt_vocab: Target language vocabulary
    :param device: Computation device
    :param max_length: Maximum length of the generated sequence
    :param beam_size: Beam width
    :param alpha: Length penalty factor
    :param record_attn_weights: Whether to save attention weights

    :return: Generated target language sentence, and attention weights if requested
    """
    # GNMT-style length penalty: ((5 + length) ** alpha) / ((5 + 1) ** alpha)
    length_penalty: Callable[[int], float] = lambda length: ((5 + length) ** alpha) / ((5 + 1) ** alpha)
    # Normalize a beam's cumulative log probability by the penalty for its length
    score_normalize: Callable[[Beam], float] = lambda beam: (
        beam.score / length_penalty(length) if (length := len(beam.tokens)) > 0 else beam.score)
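    # Worked example (assuming alpha=0.7): length_penalty(1) = 1.0,
    # length_penalty(5) = (10/6) ** 0.7 ≈ 1.43, length_penalty(10) = (15/6) ** 0.7 ≈ 1.90,
    # so longer sequences have their cumulative log probability divided by a larger
    # factor, which offsets beam search's bias toward short sequences.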

    # Get special token indices
    pad_src_index: int = src_vocab.get_index(ST.PAD)
    sos_tgt_index: int = tgt_vocab.get_index(ST.SOS)
    eos_tgt_index: int = tgt_vocab.get_index(ST.EOS)

    # Input preprocessing
    src_tokens: list[int] = src_vocab.encode(
        [*src_sentence.lower().split(), ST.EOS])  # Case conversion, tokenization, add EOS token
    src_tokens_pad_trunc: list[int] = [  # Truncate or pad to specified length
        *src_tokens[:max_length],
        *[pad_src_index] * (max_length - len(src_tokens))
    ]
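    # Illustration (hypothetical lengths): a 4-token source (3 words + EOS) with
    # max_length=20 becomes those 4 indices followed by 16 PAD indices; a source
    # longer than 20 tokens is truncated instead.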

    # Initialize the beam set (candidate sequences); an empty Beam instance represents the starting point of the search
    beams: list[Beam] = [Beam(tokens=[], score=0.0, finished=False, attn_weights=[] if record_attn_weights else None)]

    module.eval()
    with torch.no_grad():
        src_input = torch.tensor([src_tokens_pad_trunc], dtype=torch.long, device=device)  # (BATCH_SIZE=1, SEQ_LENGTH)
        src_valid_length = torch.tensor([len(src_tokens)], device=device)  # (BATCH_SIZE=1,)
        enc_output = module.encoder(src_input, src_valid_length)  # Encoder output

        # Decoding
        for step in range(max_length):
            if all(beam.finished for beam in beams):
                break  # End decoding once every candidate sequence is marked as finished

            # Process finished and unfinished candidate sequences separately
            beams_done = [beam for beam in beams if beam.finished]
            beams_live = [beam for beam in beams if not beam.finished]

            candidates = beams_done.copy()  # Add finished sequences directly to the candidate set
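            # (finished candidates still compete with longer live extensions when the
            # beam set is pruned below, using the length-normalized score)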

            for beam in beams_live:  # Continue decoding only for unfinished sequences
                # Initialize the decoder state from the encoder output, then re-run the
                # decoder over this beam's prefix to rebuild its hidden state
                dec_state = module.decoder.init_state(enc_output)
                dec_input = torch.tensor([[sos_tgt_index] + beam.tokens], dtype=torch.long, device=device)
                if step > 0:
                    _, dec_state = module.decoder(dec_input[:, :-1], dec_state)

                output, _ = module.decoder(dec_input[:, -1:], dec_state)  # Get the prediction for the current time step
                probs: Tensor = nn.functional.log_softmax(output[0][-1], dim=-1)  # Convert to log probabilities
                topk_probs, topk_indices = probs.topk(beam_size)  # Select top-k tokens with highest probabilities
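                # (probs is evidently 2-D here, shape (1, vocab_size), since the top-k
                # results are indexed as [0, i] below)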

                # Create new candidate sequences from the top-k extensions
                # (the attention weights are loop-invariant, so fetch them once)
                attn_weight = output[1].squeeze(0) if record_attn_weights and len(output) > 1 else None
                for i in range(beam_size):
                    new_beam = beam.extend(token_idx=topk_indices[0, i].item(),
                                           token_score=topk_probs[0, i].item(),
                                           eos_idx=eos_tgt_index,
                                           attn_weight=attn_weight)
                    candidates.append(new_beam)

            # Keep the beam_size candidates with the highest normalized scores
            beams = sorted(candidates, key=score_normalize, reverse=True)[:beam_size]

    best_beam = max(beams, key=score_normalize)  # Use adjusted score to select the best sequence
    tgt_sentence = ' '.join(tgt_vocab.decode(best_beam.tokens))
    stack_attn_weights = torch.stack(best_beam.attn_weights) if record_attn_weights and best_beam.attn_weights else None

    return tgt_sentence, stack_attn_weights


if __name__ == '__main__':
    from torch import optim

    from RNN.encoder_decoder import EncoderDecoder
    from RNN.seq2seq import TestSentence, SequenceLengthCrossEntropyLoss, Seq2SeqEncoder, Seq2SeqDecoder, \
        train_one_epoch, forecast_greedy_search, evaluate_bleu
    from translation_dataset_loader import nmt_eng_fra_dataloader

    BATCH_SIZE = 128
    SEQ_LENGTH = 20
    EMBED_DIM = 256
    HIDDEN_NUM = 256
    NUM_LAYERS = 2
    DROPOUT = 0.2
    LEARNING_RATE = 0.0005
    EPOCHS_NUM = 50
    TEST_INTERVAL = 1
    TEST_SENTENCES = TestSentence(src=["I like apples .",
                                       "She reads books regularly .",
                                       "They play soccer together .",
                                       "We studied French yesterday .",
                                       "The weather is beautiful today ."],
                                  tgt=[["J'aime les pommes .", "J'adore les pommes .", "Les pommes me plaisent .",
                                        "Je raffole des pommes .", "J'apprécie les pommes ."],
                                       ["Elle lit des livres régulièrement .", "Elle lit des livres souvent .",
                                        "Elle lit des livres fréquemment .", "Elle lit régulièrement des ouvrages ."],
                                       ["Ils jouent au football ensemble .", "Ils jouent au foot ensemble .",
                                        "Ils pratiquent le football ensemble .", "Ensemble, ils jouent au football ."],
                                       ["Nous avons étudié le français hier .",
                                        "Hier, nous avons étudié le français .",
                                        "Nous avons appris le français hier .", "Nous avons fait du français hier ."],
                                       ["Le temps est magnifique aujourd'hui .", "Il fait beau aujourd'hui .",
                                        "Le temps est splendide aujourd'hui .", "La météo est belle aujourd'hui ."]])

    data_iter, eng_vocab, fra_vocab = nmt_eng_fra_dataloader(BATCH_SIZE, SEQ_LENGTH, num_workers=8)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    nmt_model = EncoderDecoder(encoder=Seq2SeqEncoder(len(eng_vocab), EMBED_DIM, HIDDEN_NUM, NUM_LAYERS, DROPOUT),
                               decoder=Seq2SeqDecoder(len(fra_vocab), EMBED_DIM, HIDDEN_NUM, NUM_LAYERS, DROPOUT),
                               device=device)  # Use the default parameter initialization; no manual initialization
    optimizer = optim.Adam(nmt_model.parameters(), lr=LEARNING_RATE)
    criterion = SequenceLengthCrossEntropyLoss()

    for epoch in range(EPOCHS_NUM):
        loss, speed = train_one_epoch(nmt_model, data_iter, optimizer, criterion, fra_vocab, device)
        print(f'Epoch {epoch + 1:03}: loss {loss:.3f}, speed {speed:.1f} tokens/sec')

        if (epoch + 1) % TEST_INTERVAL == 0:
            for eng, fra in zip(TEST_SENTENCES.src, TEST_SENTENCES.tgt):
                forecast_fra_greedy, _ = forecast_greedy_search(nmt_model, eng, eng_vocab, fra_vocab, device)
                forecast_fra_beam, _ = forecast_beam_search(nmt_model, eng, eng_vocab, fra_vocab, device, alpha=0.3)
                print(f'INFO (greedy search): '
                      f'{eng.ljust(max(map(len, TEST_SENTENCES.src)))} '
                      f'→ (BLEU={evaluate_bleu(forecast_fra_greedy, fra, max_n_gram=3):.2f}) {forecast_fra_greedy}')
                print(f'INFO (beam search  ): '
                      f'{eng.ljust(max(map(len, TEST_SENTENCES.src)))} '
                      f'→ (BLEU={evaluate_bleu(forecast_fra_beam, fra, max_n_gram=3):.2f}) {forecast_fra_beam}')

Output:

Epoch 031: loss 0.024, speed 47454.5 tokens/sec
INFO (greedy search): I like apples .                  → (BLEU=0.63) j'aime les pommes .
INFO (beam search  ): I like apples .                  → (BLEU=0.00) j'aime les pommes pommes .
INFO (greedy search): She reads books regularly .      → (BLEU=0.00) elle lit avec deux balles .
INFO (beam search  ): She reads books regularly .      → (BLEU=0.00) elle lit <UNK> avec de l'argent .
INFO (greedy search): They play soccer together .      → (BLEU=1.00) ils jouent au foot ensemble .
INFO (beam search  ): They play soccer together .      → (BLEU=0.41) ils jouent jouent au foot restaurant .
INFO (greedy search): We studied French yesterday .    → (BLEU=1.00) nous avons étudié le français hier .
INFO (beam search  ): We studied French yesterday .    → (BLEU=1.00) nous avons étudié le français hier .
INFO (greedy search): The weather is beautiful today . → (BLEU=0.00) aujourd'hui , il fait beau .
INFO (beam search  ): The weather is beautiful today . → (BLEU=0.58) aujourd'hui il fait beau aujourd'hui .

(Output truncated; training was not yet complete at this point.)
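
As a sanity check (a minimal sketch, assuming the definitions above are in scope and that forecast_greedy_search post-processes tokens the same way): with beam_size=1 the search keeps a single candidate per step, and with alpha=0.0 the length penalty is constant at 1, so beam search should reduce exactly to greedy search. Any mismatch then points to a bug in the implementation rather than to beam search as a method.

# Hypothetical check: beam search with beam_size=1 and no length penalty should match greedy search
for eng in TEST_SENTENCES.src:
    greedy_out, _ = forecast_greedy_search(nmt_model, eng, eng_vocab, fra_vocab, device)
    beam1_out, _ = forecast_beam_search(nmt_model, eng, eng_vocab, fra_vocab, device,
                                        beam_size=1, alpha=0.0)
    assert beam1_out == greedy_out, f'mismatch for {eng!r}: {beam1_out!r} vs {greedy_out!r}'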