https://d2l.ai/chapter_recurrent-neural-networks/rnn-concise.html
Your RNNModel class should really take advantage of the rnn_layer's attributes (input_size, hidden_size, num_layers, bidirectional) instead of taking vocab_size as a separate argument. I'd submit a PR, but the code here doesn't seem to match the one on GitHub as of 2021-11. Here's what I have in mind:
import torch
from torch import nn
from torch.nn import functional as F

class RNNModel(nn.Module):
    """The RNN model, configured entirely from `rnn_layer`'s attributes."""
    def __init__(self, rnn_layer, **kwargs):
        super(RNNModel, self).__init__(**kwargs)
        self.rnn = rnn_layer
        # Bidirectional RNNs will be introduced later; for now we only
        # need the flag to size the output layer correctly.
        if self.rnn.bidirectional:
            self.num_directions = 2
        else:
            self.num_directions = 1
        # `input_size` equals the vocabulary size because the inputs are
        # one-hot encoded, so it also serves as the output dimension.
        self.linear = nn.Linear(
            self.num_directions * self.rnn.hidden_size, self.rnn.input_size)

    def forward(self, inputs, state):
        """
        The fully connected layer first changes the shape of `Y` to
        (`num_steps` * `batch_size`, `num_hiddens`).
        Its output shape is (`num_steps` * `batch_size`, `vocab_size`).
        """
        X = F.one_hot(inputs.T.long(), self.rnn.input_size)
        X = X.to(torch.float32)
        Y, state = self.rnn(X, state)
        # `Y` holds the hidden state at every time step; flatten the time
        # and batch dimensions before applying the output layer.
        output = self.linear(Y.reshape((-1, Y.shape[-1])))
        return output, state

    def begin_state(self, device, batch_size=1):
        """`nn.LSTM` takes a tuple of hidden states."""
        tensor = torch.zeros(
            (self.num_directions * self.rnn.num_layers, batch_size,
             self.rnn.hidden_size), device=device)
        if isinstance(self.rnn, nn.LSTM):
            # Initial hidden state and cell state; sharing one zero tensor
            # is safe here because `nn.LSTM` does not modify it in place.
            return (tensor, tensor)
        else:
            # `nn.GRU` (and `nn.RNN`) take a single tensor as hidden state.
            return tensor
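
For reference, here is a minimal shape check using the chapter's hyperparameters (vocab_size=28, num_hiddens=256, batch_size=32, num_steps=35); the dummy all-zeros token indices are just an assumption for the sanity check, not real data:

num_hiddens, vocab_size = 256, 28
batch_size, num_steps = 32, 35
device = torch.device('cpu')

rnn_layer = nn.GRU(vocab_size, num_hiddens)
net = RNNModel(rnn_layer)  # no vocab_size argument needed

state = net.begin_state(device, batch_size)
inputs = torch.zeros((batch_size, num_steps), dtype=torch.long, device=device)
output, state = net(inputs, state)
print(output.shape)  # torch.Size([1120, 28]) == (num_steps * batch_size, vocab_size)
print(state.shape)   # torch.Size([1, 32, 256]) == (num_layers * num_directions, batch_size, num_hiddens)

Everything the linear layer and begin_state need is read off rnn_layer itself, so swapping in nn.LSTM or a bidirectional layer works without touching the constructor call.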