首先,目前的中文版代码相对英文版已经落后不少了。在中文版同步更新之前,建议后面的读者先参考英文版的代码。
然后说说大家讨论的位置编码bug对应的解法。
这个问题目前在英文版里也没有解决,所以这里也贴一下我的解法,应该是比较简单直观的:
问题分析:
预测模式下,由d2l.predict_seq2seq()函数将输入token一个一个单独传入TransformerDecoder进行预测,使得TransformerDecoder内部是获取不到每个token在序列中的位置的,所以其内部调用d2l.PositionalEncoding(试图根据token位置来映射出位置编码)是肯定会出bug的。
解决方案:
让predict_seq2seq()函数将token丢失的位置信息与token一起传入Decoder,并最终交给PositionalEncoding内部正确处理。即可
具体来说:
1. predict_seq2seq()函数(9.7节)里,调用decoder解码时添加offset参数
原代码
for _ in range(num_steps):
Y, dec_state = net.decoder(dec_X, dec_state)
改为
for idx in range(num_steps):
Y, dec_state = net.decoder(dec_X, dec_state, offset=idx) // 传入表示位置的offset
2. TransformerDecoder.forward()函数,将offset透传给PositionalEncoding
原代码
def forward(self, X, state):
X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
改为
def forward(self, X, state, offset=0):
X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens), offset) // 透传offset
3. PositionalEncoding.forward()函数(10.6节),支持根据offset计算出位置编码
原代码
def forward(self, X):
X = X + self.P[:, :X.shape[1], :].to(X.device)
改为
def forward(self, X, offset=0):
X = X + self.P[:, offset:X.shape[1], :].to(X.device) // 根据offset计算出位置编码
如此修改,就可以解决预测场景下位置编码计算的bug,不影响其他方面的计算,也是在原有框架下改动最小的方案。