Transformer

首先,目前的中文版代码相对英文版已经落后不少了。在中文版同步更新之前,建议后面的读者先参考英文版的代码。

然后说说大家讨论的位置编码bug对应的解法。
这个问题目前在英文版里也没有解决,所以这里也贴一下我的解法,应该是比较简单直观的:

问题分析:
预测模式下,由d2l.predict_seq2seq()函数将输入token一个一个单独传入TransformerDecoder进行预测,使得TransformerDecoder内部是获取不到每个token在序列中的位置的,所以其内部调用d2l.PositionalEncoding(试图根据token位置来映射出位置编码)是肯定会出bug的。

解决方案:
让predict_seq2seq()函数将token丢失的位置信息与token一起传入Decoder,并最终交给PositionalEncoding内部正确处理。即可

具体来说:
1. predict_seq2seq()函数(9.7节)里,调用decoder解码时添加offset参数

原代码

for _ in range(num_steps):
        Y, dec_state = net.decoder(dec_X, dec_state)

改为

for idx in range(num_steps):
        Y, dec_state = net.decoder(dec_X, dec_state, offset=idx)  // 传入表示位置的offset

2. TransformerDecoder.forward()函数,将offset透传给PositionalEncoding

原代码

def forward(self, X, state):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))

改为

def forward(self, X, state, offset=0):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens), offset) // 透传offset

3. PositionalEncoding.forward()函数(10.6节),支持根据offset计算出位置编码

原代码

def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)

改为

def forward(self, X, offset=0):
        X = X + self.P[:, offset:X.shape[1], :].to(X.device) // 根据offset计算出位置编码

如此修改,就可以解决预测场景下位置编码计算的bug,不影响其他方面的计算,也是在原有框架下改动最小的方案。

修改后的dec_valid_lens在推理时计算不对? 在推理时,输入X是之前的累计的序列,此时的dec_valid_lens应该是None,表示当前序列的长度。
在训练时,需要进行mask处理,所有需要arange。

//而我修改的代码没有再考虑这个,是因为我即使在预测时也总是送入至今为止预测的完整时间的X,设置为None的话,除了最后一个时间步,前面时间步的后续计算就出错了,因为它们会对“未来”产生注意力。
这个做法有问题,在推理时,原始代码,每次只有一次query,然后得到value。而你的代码相当于要做当前长度次的query。
在下面回复中,我认为@cddc提到的方法是有效的。

补充:完整开了@mashiro的代码,看起来是OK的。但推理时候的效率比较低,每次预测下个词是,需要对当前的序列整个进行一次attention,也就是要进行当前长度次的query,而实际只需要进行一次的query。

我认为最后一步有些问题,应该是
X = X + self.P[:, offset:offset+X.shape[1], :].to(X.device),预测的时候X是一维的,每次的X.shape[1]=1

residual connection主要是为了传梯度缓解梯度消失的,直接把X加上去并没有什么可学习的参数、模型不会学到如何走捷径“作弊”;而要学的参数都在attention里面,所以对要学的部分即attention做mask就行了

我有两个疑问,你的思路是把之前每步预测的结果保存起来,然后当作下一步的输入,而非像示例代码一样每步预测时只输入上一步的预测结果。我没理解错吧?
首先,你修改后的代码X2 = self.attention1(X, X, X, dec_valid_lens)中的X似乎都是t时刻及以前的输入,而示例中,自注意力的Q是当前的X,K与V才是t时刻及以前的输入;另外,在示例中,每个decoderblock之间传递的都是当前时刻的信息,而你修改好后传递的是所有历史时间步的信息

赞。你说的对,我代码的最后一步有个小bug。应该是你这样的才对:
X = X + self.P[:, offset:offset+X.shape[1], :].to(X.device)

特此更正