https://d2l.ai/chapter_recurrent-neural-networks/rnn-concise.html
Your RNNModel class should really take advantage of the rnn_layer’s attributes.
I’d submit a PR, but the code here doesn’t seem to match the one on GitHub as of 2021-11:
import torch
from torch import nn
from torch.nn import functional as F

class RNNModel(nn.Module):
    """The RNN model."""
    def __init__(self, rnn_layer, **kwargs):
        super(RNNModel, self).__init__(**kwargs)
        self.rnn = rnn_layer
        # Bidirectional RNNs will be introduced later.
        self.num_directions = 2 if self.rnn.bidirectional else 1
        self.linear = nn.Linear(
            self.num_directions * self.rnn.hidden_size, self.rnn.input_size)

    def forward(self, inputs, state):
        """
        The fully connected layer first changes the shape of `Y` to
        (`num_steps` * `batch_size`, `num_hiddens`).
        Its output shape is (`num_steps` * `batch_size`, `vocab_size`).
        """
        X = F.one_hot(inputs.T.long(), self.rnn.input_size)
        X = X.to(torch.float32)
        Y, state = self.rnn(X, state)
        # `Y` holds the hidden state at every time step; flatten the
        # (time, batch) dimensions before the fully connected layer.
        output = self.linear(Y.reshape((-1, Y.shape[-1])))
        return output, state

    def begin_state(self, device, batch_size=1):
        """`nn.LSTM` takes a tuple of hidden states."""
        shape = (self.num_directions * self.rnn.num_layers, batch_size,
                 self.rnn.hidden_size)
        if isinstance(self.rnn, nn.LSTM):
            # Use two independent tensors for the hidden and cell states
            # rather than the same tensor object twice.
            return (torch.zeros(shape, device=device),
                    torch.zeros(shape, device=device))
        # `nn.GRU` takes a tensor as the hidden state.
        return torch.zeros(shape, device=device)
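The branch in begin_state reflects a real asymmetry in the PyTorch API: nn.GRU takes a single tensor as its initial hidden state, while nn.LSTM expects a (hidden, cell) tuple. A minimal sketch of both cases, with made-up toy sizes:

```python
import torch
from torch import nn

# Toy sizes chosen purely for illustration.
vocab_size, num_hiddens, num_layers, batch_size, num_steps = 10, 16, 1, 4, 5
X = torch.randn(num_steps, batch_size, vocab_size)

# `nn.GRU` takes a single tensor as its initial hidden state...
gru = nn.GRU(vocab_size, num_hiddens, num_layers)
h0 = torch.zeros(num_layers, batch_size, num_hiddens)
Y, h_n = gru(X, h0)
print(Y.shape)  # torch.Size([5, 4, 16])

# ...while `nn.LSTM` expects a (hidden state, cell state) tuple.
lstm = nn.LSTM(vocab_size, num_hiddens, num_layers)
state = (torch.zeros(num_layers, batch_size, num_hiddens),
         torch.zeros(num_layers, batch_size, num_hiddens))
Y, (h_n, c_n) = lstm(X, state)
print(Y.shape)  # torch.Size([5, 4, 16])
```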
Let’s not use inheritance here. It’s over-engineered and introduces unnecessary cognitive overhead for a tutorial, especially when you are inheriting from RNNLMScratch.
I sometimes find the implementations in the (current) Chinese version much better for learning purposes.
I’m not against reusable code, but it shouldn’t come at the cost of readability, especially for code in this book.
Definitely agreed! I use the English version to study new material (the torch API in the Chinese version is outdated), but these RNN chapters left me confused about how the data actually flows because of the code design. Luckily, after hours of effort, I finally understood it.
To be honest, there should be an understandable class diagram or data-flow diagram for the RNN.
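Short of a proper diagram, one quick way to see the data flow is to print the tensor shape at each stage. A hedged sketch with made-up sizes, mirroring the forward pass discussed above (raw nn.GRU and nn.Linear, no wrapper class):

```python
import torch
from torch import nn
from torch.nn import functional as F

# Toy sizes for illustration only.
vocab_size, num_hiddens, batch_size, num_steps = 28, 64, 3, 7
rnn = nn.GRU(vocab_size, num_hiddens)
linear = nn.Linear(num_hiddens, vocab_size)

# Token indices come in as (batch_size, num_steps)...
inputs = torch.randint(0, vocab_size, (batch_size, num_steps))
# ...and are transposed and one-hot encoded to (num_steps, batch_size, vocab_size).
X = F.one_hot(inputs.T.long(), vocab_size).to(torch.float32)
print(X.shape)       # torch.Size([7, 3, 28])

# The RNN emits a hidden state per time step: (num_steps, batch_size, num_hiddens).
state = torch.zeros(1, batch_size, num_hiddens)
Y, state = rnn(X, state)
print(Y.shape)       # torch.Size([7, 3, 64])

# Flatten (time, batch) before the output layer: (num_steps * batch_size, vocab_size).
output = linear(Y.reshape(-1, Y.shape[-1]))
print(output.shape)  # torch.Size([21, 28])
```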