他这里想表达的是,再都除4的情况下,第一个是第二个的2倍,那么第一个有效长度是第二个的俩倍,因为4个时间步loss都是一样的。
可以换LSTM的。
Encoder
Encoder中把rnn直接换成LSTM就可以。
class Seq2SeqEncoder(d2l.Encoder):
def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
dropout=0, **kwargs):
super(Seq2SeqEncoder, self).__init__(**kwargs)
self.embedding = nn.Embedding(vocab_size, embed_size)
self.rnn = nn.LSTM(embed_size, num_hiddens, num_layers,
dropout=dropout)
def forward(self, X, *args):
X = self.embedding(X)
X = X.permute(1,0,2)
output, state = self.rnn(X)
return output, state
Decoder
问题主要是Decoder中的forward里的state。由于LSTM的state输出 H
和 C
, 所以需要state[0][-1]
拿到H
中的最后一层的隐藏层,在concat就可以了。
class Seq2SeqDecoder(d2l.Decoder):
def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
droput=0, **kwargs):
super(Seq2SeqDecoder, self).__init__(**kwargs)
self.embedding = nn.Embedding(vocab_size, embed_size)
self.rnn = nn.LSTM(embed_size + num_hiddens, num_hiddens, num_layers,
dropout=droput)
self.dense = nn.Linear(num_hiddens, vocab_size)
def init_state(self, enc_outputs, *args):
return enc_outputs[1]
def forward(self, X, state):
X = self.embedding(X).permute(1,0,2)
context = state[0][-1].repeat(X.shape[0], 1, 1)
X_and_context = torch.cat((X, context), 2)
output, state = self.rnn(X_and_context, state)
output = self.dense(output).permute(1, 0, 2)
return output, state
其他都一样
@Sheepxaviera
这个应该就是新版的EncoderDecoder的forward定义和以前不一样了的问题。
你把返回的那个"Y_hat, _" 后面的下划线直接去掉就可以啦。
像这样
Y_hat = net(X, dec_input, X_valid_len)