Transformer

首先,目前的中文版代码相对英文版已经落后不少了。在中文版同步更新之前,建议后面的读者先参考英文版的代码。

然后说说大家讨论的位置编码bug对应的解法。
这个问题目前在英文版里也没有解决,所以这里也贴一下我的解法,应该是比较简单直观的:

问题分析:
预测模式下,由d2l.predict_seq2seq()函数将输入token一个一个单独传入TransformerDecoder进行预测,使得TransformerDecoder内部是获取不到每个token在序列中的位置的,所以其内部调用d2l.PositionalEncoding(试图根据token位置来映射出位置编码)是肯定会出bug的。

解决方案:
让predict_seq2seq()函数将token丢失的位置信息与token一起传入Decoder,并最终交给PositionalEncoding内部正确处理。即可

具体来说:
1. predict_seq2seq()函数(9.7节)里,调用decoder解码时添加offset参数

原代码

for _ in range(num_steps):
        Y, dec_state = net.decoder(dec_X, dec_state)

改为

for idx in range(num_steps):
        Y, dec_state = net.decoder(dec_X, dec_state, offset=idx)  // 传入表示位置的offset

2. TransformerDecoder.forward()函数,将offset透传给PositionalEncoding

原代码

def forward(self, X, state):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))

改为

def forward(self, X, state, offset=0):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens), offset) // 透传offset

3. PositionalEncoding.forward()函数(10.6节),支持根据offset计算出位置编码

原代码

def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)

改为

def forward(self, X, offset=0):
        X = X + self.P[:, offset:X.shape[1], :].to(X.device) // 根据offset计算出位置编码

如此修改,就可以解决预测场景下位置编码计算的bug,不影响其他方面的计算,也是在原有框架下改动最小的方案。

修改后的dec_valid_lens在推理时计算不对? 在推理时,输入X是之前的累计的序列,此时的dec_valid_lens应该是None,表示当前序列的长度。
在训练时,需要进行mask处理,所有需要arange。

//而我修改的代码没有再考虑这个,是因为我即使在预测时也总是送入至今为止预测的完整时间的X,设置为None的话,除了最后一个时间步,前面时间步的后续计算就出错了,因为它们会对“未来”产生注意力。
这个做法有问题,在推理时,原始代码,每次只有一次query,然后得到value。而你的代码相当于要做当前长度次的query。
在下面回复中,我认为@cddc提到的方法是有效的。

补充:完整开了@mashiro的代码,看起来是OK的。但推理时候的效率比较低,每次预测下个词是,需要对当前的序列整个进行一次attention,也就是要进行当前长度次的query,而实际只需要进行一次的query。

我认为最后一步有些问题,应该是
X = X + self.P[:, offset:offset+X.shape[1], :].to(X.device),预测的时候X是一维的,每次的X.shape[1]=1

residual connection主要是为了传梯度缓解梯度消失的,直接把X加上去并没有什么可学习的参数、模型不会学到如何走捷径“作弊”;而要学的参数都在attention里面,所以对要学的部分即attention做mask就行了

我有两个疑问,你的思路是把之前每步预测的结果保存起来,然后当作下一步的输入,而非像示例代码一样每步预测时只输入上一步的预测结果。我没理解错吧?
首先,你修改后的代码X2 = self.attention1(X, X, X, dec_valid_lens)中的X似乎都是t时刻及以前的输入,而示例中,自注意力的Q是当前的X,K与V才是t时刻及以前的输入;另外,在示例中,每个decoderblock之间传递的都是当前时刻的信息,而你修改好后传递的是所有历史时间步的信息

赞。你说的对,我代码的最后一步有个小bug。应该是你这样的才对:
X = X + self.P[:, offset:offset+X.shape[1], :].to(X.device)

特此更正

隔了很久,发现这个bug还有不少人在讨论,并且中文版也很久没有更新过了(看了眼英文版也还有这个bug)。我回来更新一下我当时的修改,虽然我当时的修改在逻辑上是正确的,但确实很多人提到了,预测时每次都用完整的当前预测序列做query进行注意力运算效率太低了。

我现在更新一下我的修改方案,我现在的修改方案只变更三处,应该更好理解。

先回顾一下bug原因:在预测时,predict_seq2seq函数中每次送入解码器的X都是1批量1时间步的,因此TransformerDecoder的forward函数第一步X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))将永远使用第一位置,导致在预测时解码器的位置编码失效。

为了使运行逻辑正确需要进行一些修改,我对TransformerDecoder类的init_state和forward函数进行了少量修改修复了这个问题。大致思想是在预测时把之前每一步的预测结果拼接在一起保存在TransformerDecoder对象里,在进行位置编码时使用完整的预测序列,然后再取出最后一步的X即可,此时X就具有正确的位置编码,并且运算成本不会显著增加。

def init_state(self, enc_outputs, enc_valid_lens, *args):
    self.seqX = None # modify
    return [enc_outputs, enc_valid_lens, [None] * self.num_layers]

def forward(self, X, state):
    if not self.training: # modify
        self.seqX = X if self.seqX is None else torch.cat((self.seqX, X), dim=1)
        X = self.seqX

    X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))

    if not self.training: # modify
        X = X[:, -1:, :]
    
    # 后面都和原书相同
    self._attention_weights = [[None] * len(self.blks) for _ in range (2)]

我觉得如果我一开始就这么改,可能更能解决大家的困惑,而我当时限于水平不足,把一个简单的bug给整的复杂化了,深感抱歉。

我认为,关于 predict_seq2seq 中位置编码的 BUG 修改两处即可:

  1. 在 PositionalEncoding 类的 forward 方法中增加对 开始索引 的支持;
  2. 在 TransformerDecoder 类的 forward 方法中根据 state[2] 中缓存的中间结果的 序列长度 确定位置编码的开始索引;

然而其实或许有个更好的解决方法:

确实,源码中仅仅对上一时间步的预测词X进行位置编码肯定是错的;
但是源码的DecoderBlock的X2 = self.attention1(X, key_values, key_values, dec_valid_lens)是没错的(因为query是上一个时间步的预测词,key和value是之前所有时间步的预测词);

因此,可以直接对位置编码进行修改,加一个offset即可:

class PositionalEncodingOffset(nn.Module):
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # 预生成 [1, max_len, d_model] 的位置编码矩阵 self.P
        P = torch.zeros((1, max_len, num_hiddens))
        X = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / \
            torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        P[:, :, 0::2] = torch.sin(X)
        P[:, :, 1::2] = torch.cos(X)
        self.register_buffer('P', P)  # 不参与梯度
    def forward(self, X, offset: int = 0):
        # X: [B, T, D]; 取 P 的切片 [offset:offset+T]
        X = X + self.P[:, offset: offset + X.shape[1], :]
        return self.dropout(X)

然后对TransformerDecoder的forward也修改一下,加个offset(也就是现在在预测哪个词):

def forward(self, X, state):
    enc_outputs, enc_valid_lens, cache_list, t = state
    X = self.embedding(X) * math.sqrt(self.num_hiddens)
    if self.training:
        X = self.pos_encoding(X, offset=0)
    else:
        # 用 offset=t 的位置编码
        X = self.pos_encoding(X, offset=t)
    # 过各个解码块(内部会把当前步拼到各层自己的 KV cache 上)
    self._attention_weights = [[None] * len(self.blks) for _ in range(2)]
    for i, blk in enumerate(self.blks):
        X, state = blk(X, state)
        self._attention_weights[0][i] = blk.attention1.attention.attention_weights
        self._attention_weights[1][i] = blk.attention2.attention.attention_weights
    # 推理时把 t 累加已处理的步数(通常是 1)
    if not self.training:
        state[3] = t + X.shape[1]
    return self.dense(X), state

既然源码位置编码的问题在于第i步的x位置编码编的是第0步的码,那么加个offset,编第i步的码就好了

d2l版本更新,原本的函数发生了变动(可能被删除)

位置编码修改,非常简洁

class PositionalEncoding(nn.Module):
“”“位置编码”“”

def __init__(self, num_hiddens, dropout, max_len=1000):
    super(PositionalEncoding, self).__init__()
    self.dropout = nn.Dropout(dropout)
    self.P = torch.zeros((1, max_len, num_hiddens))
    X = torch.arange(max_len, dtype=torch.float32). \
            reshape(-1, 1) / torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
    self.P[:, :, 0::2] = torch.sin(X)
    self.P[:, :, 1::2] = torch.cos(X)
    self.t=0

def forward(self, X):
    X = X + self.P[:, self.t:X.shape[1]+self.t, :].to(X.device)
    if X.shape[1]==1:
        self.t+=1
    return self.dropout(X)

我建议大家本章的相关代码可以参考PyTorch API的架构。主要是这四个类
torch.nn.TransformerEncoderLayer()
torch.nn.TransformerEncoder()
torch.nn.TransformerDecoderLayer()
torch.nn.TransformerDecoder()
可以发现,从xxcoderLayer到xxcoder的时候,PyTorch是直接把构造好的一个Layer输入到xxcoder。


TransformerDecoder源码如上,可以看到,直接将输入的Layer复制了num_layers个然后保存。这时候是并没有把第i个layer这个信息保存到layer中的。确实也不需要,因为每个layer的工作基本是一样的。
D2L这引入i的原因就是为了在预测的时候,对于每一个Layer_i,在对一个序列运行(预测)到第n次的时候,都要把本次得到的x拼接在前面的n-1个x后,才能得到第n步的kv对(decoderLayer的第1个子层,看架构图很清晰)。
我的做法是直接在TransformerDecoderLayer增加一个字段,保存了这个kv对信息。每一次执行上下文都不一样,因此在Decoder执行init_state方法时,遍历decoderLayer清除这个字段信息即可。以下是我的部分代码:

class TransformerDecoderLayer(nn.Module):
    def __init__(self, embed_dim: int,
                 num_heads: int,
                 ffn_hiddens: int,
                 dropout: float = 0,
                 bias: bool = False):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.masked_attention = nn.MultiheadAttention(embed_dim, num_heads, dropout, bias=bias, batch_first=True)
        self.add_norm1 = AddNorm(embed_dim, dropout)
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout, bias=bias, batch_first=True)
        self.add_norm2 = AddNorm(embed_dim, dropout)
        self.ffn = PositionWiseFeedForwardNetwork(ffn_hiddens, embed_dim)
        self.add_norm3 = AddNorm(embed_dim, dropout)

        self.key_values: Tensor | None = None
        self.masked_attn_weights = None
        self.attn_weights = None

    def forward(self, x: Tensor, state: tuple[Tensor, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor]]:
        (encoder_x, valid_lens) = state
        if self.key_values is None:
            self.key_values = x
        else:
            self.key_values = torch.cat((self.key_values, x), dim=1)

        batch_size, seq_len, _ = x.shape
        attn_mask1 = torch.ones(seq_len, seq_len, device=x.device).triu(
            diagonal=1).repeat(self.num_heads * batch_size, 1, 1).bool() if self.training else None
        x1, attn1 = self.masked_attention(x, self.key_values, self.key_values, attn_mask=attn_mask1)
        self.masked_attn_weights = attn1
        x2 = self.add_norm1(x, x1)

        x3, attn2 = self.attention(x2, encoder_x, encoder_x,
                                   attn_mask=attn_mask(valid_lens, seq_len, encoder_x.shape[1],
                                                       num_heads=self.num_heads))
        self.attn_weights = attn2
        x4 = self.add_norm2(x2, x3)

        x5 = self.ffn(x4)
        return self.add_norm3(x4, x5), state

    def reset_attention_weights(self):
        self.masked_attn_weights = None
        self.attn_weights = None
class TransformerDecoder(d2l.AttentionDecoder):
    def __init__(self, decoder_layer: TransformerDecoderLayer, num_layers: int, vocab_size: int, dropout: float = 0):
        super().__init__()
        self.embed_dim = decoder_layer.embed_dim
        self.num_layers = num_layers

        self.embedding = nn.Embedding(vocab_size, self.embed_dim)
        self.pos_encoding = PositionalEncoding(self.embed_dim, dropout=dropout)

        self.layers = nn.ModuleList(copy.deepcopy(decoder_layer) for _ in range(num_layers))

        self.dense = nn.LazyLinear(vocab_size)

        self._masked_attn_weights: list[Tensor] = []
        self._attn_weights: list[Tensor] = []

    @property
    def attention_weights(self) -> tuple[Tensor, Tensor]:
        """"
        获取并合并所有解码时间步的注意力权重
        :return: 注意力权重张量,形状为 (num_layers, batch_size, queries_num, keys_num)
        """
        return attn_merge(self._masked_attn_weights, self.num_layers), attn_merge(self._attn_weights, self.num_layers)

    def init_state(self, enc_all_outputs, *args):
        for layer in self.layers:
            layer.key_values = None
        return enc_all_outputs, *args

    def forward(self, x, state):
        ems = self.embedding(x)
        x = self.pos_encoding(ems * math.sqrt(self.embed_dim))

        for layer in self.layers:
            x, state = layer(x, state)
            self._masked_attn_weights.append(layer.masked_attn_weights)
            self._attn_weights.append(layer.attn_weights)
        return self.dense(x), state

    def reset_attention_weights(self):
        for layer in self.layers:
            layer.reset_attention_weights()
        self._masked_attn_weights = []
        self._attn_weights = []