# Pretraining BERT

Hi,

I am wondering if there is an error in PyTorch's `_get_batch_loss_bert()`:

```python
mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) * mlm_weights_X.reshape(-1, 1)
mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)
```

Earlier, `loss` is defined as `nn.CrossEntropyLoss()`, so it uses the default `reduction='mean'`. My impression is that we should use `reduction='none'` instead, so that the loss is averaged over the variable number of real tokens (`mlm_weights_X.sum()`) rather than over the real tokens plus the pad tokens, which is what the code currently seems to do.

In terms of dimensionality: `mlm_Y_hat` has shape (n, tau, v), `mlm_Y` has shape (n, tau) and `mlm_weights_X` has shape (n, tau). Currently `loss(mlm_Y_hat…)` yields a scalar, which is then multiplied by `mlm_weights_X.reshape(-1, 1)` of shape (n·tau, 1), giving a tensor of shape (n·tau, 1). The second line sums it up and divides by the number of non-zero entries in `mlm_weights_X`, which effectively gives back the original scalar `loss(mlm_Y_hat…)`: an average over both the real and the pad tokens.
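A small numerical check of this argument, as a sketch: the sizes `n`, `tau`, `v` and the random tensors below are made up for illustration, and `loss` keeps the default `reduction='mean'` as in the code being discussed.

```python
import torch
from torch import nn

torch.manual_seed(0)
n, tau, v = 2, 3, 5  # hypothetical batch size, prediction positions, vocab size

mlm_Y_hat = torch.randn(n, tau, v)            # logits, shape (n, tau, v)
mlm_Y = torch.randint(0, v, (n, tau))         # labels, shape (n, tau)
mlm_weights_X = torch.tensor([[1., 1., 0.],   # 1 = real prediction, 0 = padding
                              [1., 0., 0.]])

loss = nn.CrossEntropyLoss()  # default reduction='mean' -> returns a scalar

# The expression under discussion
mlm_l = loss(mlm_Y_hat.reshape(-1, v), mlm_Y.reshape(-1)) * mlm_weights_X.reshape(-1, 1)
mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)

# Plain mean over all n * tau positions, padding included
plain_mean = loss(mlm_Y_hat.reshape(-1, v), mlm_Y.reshape(-1))
print(mlm_l.item(), plain_mean.item())  # agree up to the 1e-8 term
```

Because the scalar is multiplied by every weight and then divided by the sum of the same weights, the weights cancel out, so the masking has no effect.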

I think that if the MLM loss had been defined as `loss = nn.CrossEntropyLoss(reduction='none')`, then `loss(mlm_Y_hat…)` would yield a tensor of shape (n·tau,), which, multiplied elementwise by the weights flattened to the same shape (n·tau,), would give a vector of shape (n·tau,). The second line would then sum it up and divide by the number of real tokens, so the pad tokens would not affect the average.
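A minimal sketch of this masked average, reusing the same made-up sizes and random tensors as an illustration. It flattens the weights with `reshape(-1)` so the product stays elementwise, and compares the result with a mean taken only over the positions whose weight is 1.

```python
import torch
from torch import nn

torch.manual_seed(0)
n, tau, v = 2, 3, 5  # hypothetical sizes
mlm_Y_hat = torch.randn(n, tau, v)
mlm_Y = torch.randint(0, v, (n, tau))
mlm_weights_X = torch.tensor([[1., 1., 0.],
                              [1., 0., 0.]])

loss = nn.CrossEntropyLoss(reduction='none')  # one loss value per position

per_pos = loss(mlm_Y_hat.reshape(-1, v), mlm_Y.reshape(-1))  # shape (n * tau,)
mlm_l = per_pos * mlm_weights_X.reshape(-1)                  # elementwise, still (n * tau,)
mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)

# Reference: average over real (weight 1) positions only
valid = mlm_weights_X.reshape(-1).bool()
masked_mean = per_pos[valid].mean()
print(mlm_l.item(), masked_mean.item())
```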


I agree. The `CrossEntropyLoss` object should have been instantiated with the `reduction` parameter set to `'none'`. Otherwise it averages the loss over both valid and invalid (padded) tokens. The correct instantiation statement is:

```python
loss = nn.CrossEntropyLoss(reduction='none')
```
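The difference between the two settings shows up directly in the output shape. A quick sketch with arbitrary sizes (note that the value is the string `'none'`, not Python's `None`):

```python
import torch
from torch import nn

logits = torch.randn(6, 5)          # 6 positions, hypothetical vocab of 5
labels = torch.randint(0, 5, (6,))

mean_loss = nn.CrossEntropyLoss()                  # reduction='mean' by default
none_loss = nn.CrossEntropyLoss(reduction='none')  # keeps one value per position

print(mean_loss.reduction)              # mean
print(mean_loss(logits, labels).shape)  # torch.Size([])  (a scalar)
print(none_loss(logits, labels).shape)  # torch.Size([6]) (one loss per position)
```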

Also, the line that filters out the losses for padding tokens

```python
mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) * \
    mlm_weights_X.reshape(-1, 1)
```

should be changed to

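```python
# Flatten the weights to the same 1-D shape as the per-position losses,
# so the product is elementwise instead of broadcasting to a 2-D matrix
mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) * \
    mlm_weights_X.reshape(-1)
```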

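After this revised line has been executed, `mlm_l` has shape `(5120,)` (here, batch size = 512 and maximum number of masked tokens = 10, so 512 × 10 = 5120). With `reshape(-1, 1)` instead, the `(5120,)` loss tensor would be broadcast against a `(5120, 1)` weight tensor, producing a `(5120, 5120)` matrix.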
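The two shapes can be checked directly. In this sketch, random values stand in for the per-position losses and all-ones weights stand in for `mlm_weights_X`; only the sizes (512 × 10) come from the post.

```python
import torch

batch_size, max_num_mlm_preds = 512, 10               # sizes from the post
per_pos = torch.rand(batch_size * max_num_mlm_preds)  # stand-in for per-position losses, shape (5120,)
weights = torch.ones(batch_size, max_num_mlm_preds)   # stand-in for mlm_weights_X

wrong = per_pos * weights.reshape(-1, 1)  # (5120,) * (5120, 1) broadcasts to 2-D
right = per_pos * weights.reshape(-1)     # (5120,) * (5120,) is elementwise

print(wrong.shape)  # torch.Size([5120, 5120])
print(right.shape)  # torch.Size([5120])
```

Because the broadcast happens silently rather than raising an error, the following `sum()` divided by `mlm_weights_X.sum()` would reduce to the plain sum of all per-position losses, padding included, so checking the shape before the reduction is a cheap safeguard.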


Hello, I used this method to pretrain a BERT model on a humongous DNA corpus. The "problem" is that pretraining finishes in just 15 minutes for the entire dataset, whereas pretraining BERT from scratch would take at least a few days (depending on the available compute). Could someone let me know why this is happening and whether the model is learning or not?

I agree with you: the original code outputs the same loss as `loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1))`.