Pretraining BERT

https://d2l.ai/chapter_natural-language-processing-pretraining/bert-pretraining.html

Hi,

I am wondering if there is an error in the PyTorch version of _get_batch_loss_bert():

mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) * mlm_weights_X.reshape(-1, 1)
mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)

Earlier in the section, the loss was defined as loss = nn.CrossEntropyLoss(), so it adopts the default reduction='mean'. But my impression is that we should use reduction='none' instead, so that the average is taken over the variable number of real tokens (mlm_weights_X.sum()) rather than over the real tokens plus the pad tokens, as it currently seems to be doing?
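
A quick way to confirm the default (this is a two-line check of mine, not code from the book):

import torch
from torch import nn

loss = nn.CrossEntropyLoss()
print(loss.reduction)  # prints 'mean': all n*tau positions, pads included, enter the average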

In terms of dimensionality: mlm_Y_hat has shape (n, tau, v), mlm_Y has shape (n, tau), and mlm_weights_X has shape (n, tau). Currently, loss(mlm_Y_hat…) yields a scalar, which is then broadcast against mlm_weights_X.reshape(-1, 1) of shape (n·tau, 1), giving a column vector of shape (n·tau, 1). The second line then sums this up and divides by mlm_weights_X.sum(), which (up to the 1e-8 epsilon) is the number of real tokens, since the weights are 0/1. This effectively gives back the original scalar loss(mlm_Y_hat…), which is an average over both real and pad tokens.
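
To make the cancellation concrete, here is a small sketch with made-up shapes (n=2, tau=3, v=5 are my own toy values, not from the book):

import torch
from torch import nn

n, tau, v = 2, 3, 5
mlm_Y_hat = torch.randn(n, tau, v)             # toy predictions
mlm_Y = torch.randint(0, v, (n, tau))          # toy labels
mlm_weights_X = torch.tensor([[1., 1., 0.],    # 0 marks pad positions
                              [1., 0., 0.]])

loss = nn.CrossEntropyLoss()                   # default reduction='mean'
scalar = loss(mlm_Y_hat.reshape(-1, v), mlm_Y.reshape(-1))
mlm_l = scalar * mlm_weights_X.reshape(-1, 1)  # shape (6, 1): the scalar broadcast over the weights
mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)
print(torch.isclose(mlm_l, scalar))            # True: the weighting cancels out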

I think if the MLM loss had instead been defined as loss = nn.CrossEntropyLoss(reduction='none'), then loss(mlm_Y_hat…) would yield a per-token loss vector of shape (n·tau,). Multiplying it elementwise by the flattened weights mlm_weights_X.reshape(-1), also of shape (n·tau,), would zero out the pad positions. (Note the weights would need to be flattened with reshape(-1) rather than reshape(-1, 1), since a (n·tau,) vector times a (n·tau, 1) column broadcasts to an (n·tau, n·tau) matrix.) The second line would then sum this up and divide by the number of real tokens, so the pad tokens would no longer affect the average.
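
For concreteness, here is a sketch of what I mean, reusing the toy shapes from above (again my own code, not the book's):

import torch
from torch import nn

n, tau, v = 2, 3, 5
mlm_Y_hat = torch.randn(n, tau, v)
mlm_Y = torch.randint(0, v, (n, tau))
mlm_weights_X = torch.tensor([[1., 1., 0.],
                              [1., 0., 0.]])

loss = nn.CrossEntropyLoss(reduction='none')                # per-token losses, no averaging
mlm_l = loss(mlm_Y_hat.reshape(-1, v), mlm_Y.reshape(-1))   # shape (n*tau,)
mlm_l = mlm_l * mlm_weights_X.reshape(-1)                   # zero out pad positions, stays a vector
mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)          # average over real tokens only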