Bidirectional Encoder Representations from Transformers (BERT)


I don’t quite understand the code below.

# PyTorch by default won't flatten the tensor as seen in mxnet where, if
# flatten=True, all but the first axis of input data are collapsed together
encoded_X = torch.flatten(encoded_X, start_dim=1)
# input_shape for NSP: (batch size, `num_hiddens`)
nsp = NextSentencePred(encoded_X.shape[-1])
nsp_Y_hat = nsp(encoded_X)

It might be better to change it to something like nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :])), as in Section 6. A relevant quote:

Hence, the output layer (self.output) of the MLP classifier takes X as the input, where X is the output of the MLP hidden layer whose input is the encoded “<cls>” token.
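
To make the difference concrete, here is a minimal sketch comparing the two approaches. The tensor sizes, the random encoded_X, and the plain nn.Linear layers standing in for the hidden layer and NextSentencePred are all assumptions for illustration, not the book's actual classes.

```python
import torch
from torch import nn

# Hypothetical sizes, chosen only for illustration.
batch_size, num_steps, num_hiddens = 2, 8, 768

# Random stand-in for the BERT encoder output:
# shape (batch size, number of tokens, num_hiddens).
encoded_X = torch.randn(batch_size, num_steps, num_hiddens)

# Approach in the snippet above: collapse every token position into one vector.
# The classifier input size then depends on the sequence length.
flat_X = torch.flatten(encoded_X, start_dim=1)
print(flat_X.shape)  # (batch_size, num_steps * num_hiddens)

# Suggested approach: keep only position 0, the encoded <cls> token.
cls_X = encoded_X[:, 0, :]
print(cls_X.shape)  # (batch_size, num_hiddens)

# Assumed stand-ins for the MLP hidden layer and the NSP output layer.
hidden = nn.Sequential(nn.Linear(num_hiddens, num_hiddens), nn.Tanh())
nsp = nn.Linear(num_hiddens, 2)  # two classes: is-next / not-next

nsp_Y_hat = nsp(hidden(cls_X))
print(nsp_Y_hat.shape)  # (batch_size, 2)
```

With flattening, the input width is num_steps * num_hiddens, so the classifier only works for one fixed sequence length. Indexing the <cls> position keeps the input at num_hiddens regardless of length, which seems to match the quoted description.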