Pretraining BERT

https://d2l.ai/chapter_natural-language-processing-pretraining/bert-pretraining.html

Hi,

I am wondering if there is an error in the PyTorch version of _get_batch_loss_bert():

mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) * mlm_weights_X.reshape(-1, 1)
mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)

Earlier, the loss was defined as loss = nn.CrossEntropyLoss(), so it adopts the default reduction='mean'. But my impression is that we should use reduction='none' instead, so that the average is taken over the variable number of real tokens (mlm_weights_X.sum()) and not over the real tokens plus the pad tokens, as it currently seems to be doing?

In terms of dimensionality: mlm_Y_hat has shape (n, tau, v), mlm_Y has shape (n, tau) and mlm_weights_X has shape (n, tau). Currently loss(mlm_Y_hat…) yields a scalar, which is then multiplied by mlm_weights_X.reshape(-1, 1) of shape (n·tau, 1), giving a tensor of shape (n·tau, 1). In the second line this is summed up and divided by mlm_weights_X.sum() (i.e. the number of real tokens), which effectively gives back the original scalar loss(mlm_Y_hat…), an average over both the real and the pad tokens.

I think if for MLM the loss had been defined as loss = nn.CrossEntropyLoss(reduction='none'), then loss(mlm_Y_hat…) would yield a tensor of shape (n·tau,), which multiplied elementwise by mlm_weights_X.reshape(-1) of shape (n·tau,) would give a vector of shape (n·tau,). The second line would then sum it up and divide it by the number of real tokens, without the pad tokens affecting the average.
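
To make the difference concrete, here is a minimal sketch (toy shapes and random tensors, not the book's training data) contrasting the default reduction='mean' with reduction='none' when the padding mask is applied:

import torch
from torch import nn

n, tau, v = 2, 3, 5                      # batch size, masked positions, vocab size
mlm_Y_hat = torch.randn(n, tau, v)       # predicted logits
mlm_Y = torch.randint(0, v, (n, tau))    # target token ids
mlm_weights_X = torch.tensor([[1., 1., 0.],   # 1 = real masked token, 0 = padding
                              [1., 0., 0.]])

# Default reduction='mean': a scalar averaged over all n*tau positions,
# so multiplying by the weights afterwards cannot exclude the pad positions.
mean_loss = nn.CrossEntropyLoss()(mlm_Y_hat.reshape(-1, v), mlm_Y.reshape(-1))

# reduction='none': per-position losses of shape (n*tau,), which can be
# masked and then averaged over the real tokens only.
loss = nn.CrossEntropyLoss(reduction='none')
per_pos = loss(mlm_Y_hat.reshape(-1, v), mlm_Y.reshape(-1))   # shape (6,)
mlm_l = (per_pos * mlm_weights_X.reshape(-1)).sum() / (mlm_weights_X.sum() + 1e-8)

print(mean_loss.shape, per_pos.shape, mlm_l.shape)
# torch.Size([]) torch.Size([6]) torch.Size([])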

I agree. The CrossEntropyLoss object should have been instantiated with the reduction parameter set to 'none'. Otherwise it averages the loss over both valid and invalid (padded) tokens. The correct instantiation statement is

loss = nn.CrossEntropyLoss(reduction='none')

Also, the code that filters out the losses for padding tokens

mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) *\
        mlm_weights_X.reshape(-1, 1)

should be changed to

mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) *\
        mlm_weights_X.reshape(-1)

After this (revised) line of code has been executed, mlm_l would have a shape of (5120,) (here, the batch size is 512 and the maximum number of masked tokens is 10, so 512 × 10 = 5120).
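
For completeness, a small shape check under the numbers quoted above; the vocabulary size and the random weight mask below are arbitrary placeholders, not values from the book:

import torch
from torch import nn

batch_size, max_num_mlm_preds, vocab_size = 512, 10, 10000  # vocab_size is a placeholder
loss = nn.CrossEntropyLoss(reduction='none')

mlm_Y_hat = torch.randn(batch_size, max_num_mlm_preds, vocab_size)
mlm_Y = torch.randint(0, vocab_size, (batch_size, max_num_mlm_preds))
mlm_weights_X = (torch.rand(batch_size, max_num_mlm_preds) > 0.3).float()  # 1 = real, 0 = pad

mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) *\
        mlm_weights_X.reshape(-1)
print(mlm_l.shape)  # torch.Size([5120])

# Averaging over the real (non-padded) masked tokens only:
mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)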