Transformer

首先,目前的中文版代码相对英文版已经落后不少了。在中文版同步更新之前,建议后面的读者先参考英文版的代码。

然后说说大家讨论的位置编码bug对应的解法。
这个问题目前在英文版里也没有解决,所以这里也贴一下我的解法,应该是比较简单直观的:

问题分析:
预测模式下,由d2l.predict_seq2seq()函数将输入token一个一个单独传入TransformerDecoder进行预测,使得TransformerDecoder内部是获取不到每个token在序列中的位置的,所以其内部调用d2l.PositionalEncoding(试图根据token位置来映射出位置编码)是肯定会出bug的。

解决方案:
让predict_seq2seq()函数将token丢失的位置信息与token一起传入Decoder,并最终交给PositionalEncoding内部正确处理。即可

具体来说:
1. predict_seq2seq()函数(9.7节)里,调用decoder解码时添加offset参数

原代码

for _ in range(num_steps):
        Y, dec_state = net.decoder(dec_X, dec_state)

改为

for idx in range(num_steps):
        Y, dec_state = net.decoder(dec_X, dec_state, offset=idx)  // 传入表示位置的offset

2. TransformerDecoder.forward()函数,将offset透传给PositionalEncoding

原代码

def forward(self, X, state):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))

改为

def forward(self, X, state, offset=0):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens), offset) // 透传offset

3. PositionalEncoding.forward()函数(10.6节),支持根据offset计算出位置编码

原代码

def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)

改为

def forward(self, X, offset=0):
        X = X + self.P[:, offset:X.shape[1], :].to(X.device) // 根据offset计算出位置编码

如此修改,就可以解决预测场景下位置编码计算的bug,不影响其他方面的计算,也是在原有框架下改动最小的方案。

修改后的dec_valid_lens在推理时计算不对? 在推理时,输入X是之前的累计的序列,此时的dec_valid_lens应该是None,表示当前序列的长度。
在训练时,需要进行mask处理,所有需要arange。

//而我修改的代码没有再考虑这个,是因为我即使在预测时也总是送入至今为止预测的完整时间的X,设置为None的话,除了最后一个时间步,前面时间步的后续计算就出错了,因为它们会对“未来”产生注意力。
这个做法有问题,在推理时,原始代码,每次只有一次query,然后得到value。而你的代码相当于要做当前长度次的query。
在下面回复中,我认为@cddc提到的方法是有效的。

补充:完整开了@mashiro的代码,看起来是OK的。但推理时候的效率比较低,每次预测下个词是,需要对当前的序列整个进行一次attention,也就是要进行当前长度次的query,而实际只需要进行一次的query。

我认为最后一步有些问题,应该是
X = X + self.P[:, offset:offset+X.shape[1], :].to(X.device),预测的时候X是一维的,每次的X.shape[1]=1

residual connection主要是为了传梯度缓解梯度消失的,直接把X加上去并没有什么可学习的参数、模型不会学到如何走捷径“作弊”;而要学的参数都在attention里面,所以对要学的部分即attention做mask就行了

我有两个疑问,你的思路是把之前每步预测的结果保存起来,然后当作下一步的输入,而非像示例代码一样每步预测时只输入上一步的预测结果。我没理解错吧?
首先,你修改后的代码X2 = self.attention1(X, X, X, dec_valid_lens)中的X似乎都是t时刻及以前的输入,而示例中,自注意力的Q是当前的X,K与V才是t时刻及以前的输入;另外,在示例中,每个decoderblock之间传递的都是当前时刻的信息,而你修改好后传递的是所有历史时间步的信息

赞。你说的对,我代码的最后一步有个小bug。应该是你这样的才对:
X = X + self.P[:, offset:offset+X.shape[1], :].to(X.device)

特此更正

隔了很久,发现这个bug还有不少人在讨论,并且中文版也很久没有更新过了(看了眼英文版也还有这个bug)。我回来更新一下我当时的修改,虽然我当时的修改在逻辑上是正确的,但确实很多人提到了,预测时每次都用完整的当前预测序列做query进行注意力运算效率太低了。

我现在更新一下我的修改方案,我现在的修改方案只变更三处,应该更好理解。

先回顾一下bug原因:在预测时,predict_seq2seq函数中每次送入解码器的X都是1批量1时间步的,因此TransformerDecoder的forward函数第一步X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))将永远使用第一位置,导致在预测时解码器的位置编码失效。

为了使运行逻辑正确需要进行一些修改,我对TransformerDecoder类的init_state和forward函数进行了少量修改修复了这个问题。大致思想是在预测时把之前每一步的预测结果拼接在一起保存在TransformerDecoder对象里,在进行位置编码时使用完整的预测序列,然后再取出最后一步的X即可,此时X就具有正确的位置编码,并且运算成本不会显著增加。

def init_state(self, enc_outputs, enc_valid_lens, *args):
    self.seqX = None # modify
    return [enc_outputs, enc_valid_lens, [None] * self.num_layers]

def forward(self, X, state):
    if not self.training: # modify
        self.seqX = X if self.seqX is None else torch.cat((self.seqX, X), dim=1)
        X = self.seqX

    X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))

    if not self.training: # modify
        X = X[:, -1:, :]
    
    # 后面都和原书相同
    self._attention_weights = [[None] * len(self.blks) for _ in range (2)]

我觉得如果我一开始就这么改,可能更能解决大家的困惑,而我当时限于水平不足,把一个简单的bug给整的复杂化了,深感抱歉。