https://d2l.ai/chapter_natural-language-processing-pretraining/bert-pretraining.html
Hi,
I am wondering if there is an error in the PyTorch version of _get_batch_loss_bert():
mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) * mlm_weights_X.reshape(-1, 1)
mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)
Earlier, the loss was defined as loss = nn.CrossEntropyLoss(), so it adopts the default reduction='mean'. But my impression is that we should use reduction='none' instead, in order to average over the variable number of real tokens (mlm_weights_X.sum()) and not over the real tokens plus pad tokens, as it currently seems to be doing?
In terms of dimensionality: mlm_Y_hat has shape (n, tau, v), mlm_Y has shape (n, tau), and mlm_weights_X has shape (n, tau). Currently loss(mlm_Y_hat, ...) yields a scalar, which is then multiplied by mlm_weights_X.reshape(-1, 1) of shape (n*tau, 1), giving a tensor of the weights scaled by that scalar. The second line then sums it up and divides by the number of non-zero components of mlm_weights_X, effectively giving back the original scalar loss(mlm_Y_hat, ...), which is an average over the real and pad tokens together.
I think if for MLM the loss had been defined as loss = nn.CrossEntropyLoss(reduction='none'), then loss(mlm_Y_hat, ...) would yield a tensor of shape (n*tau,). Multiplied elementwise by the weights flattened to mlm_weights_X.reshape(-1) (keeping the current reshape(-1, 1) would broadcast the product to an (n*tau, n*tau) matrix), this gives a vector of per-token losses with zeros at the pad positions. The second line would then sum it up and divide by the number of real tokens, without the pad tokens affecting the average.
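To make the shape argument concrete, here is a minimal, self-contained sketch with made-up sizes (n = 2, tau = 3, v = 5) and random values; the tensor names mirror the ones above:

import torch
from torch import nn

n, tau, vocab_size = 2, 3, 5
mlm_Y_hat = torch.randn(n, tau, vocab_size)     # predicted logits
mlm_Y = torch.randint(0, vocab_size, (n, tau))  # target token ids
mlm_weights_X = torch.tensor([[1., 1., 0.],     # 1 = real masked token
                              [1., 0., 0.]])    # 0 = padding

# Default reduction='mean': already a scalar averaged over all n*tau
# positions, real and padded alike; weighting afterwards cannot undo that.
loss_mean = nn.CrossEntropyLoss()
l = loss_mean(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1))
print(l.shape)  # torch.Size([]) -- a scalar

# reduction='none': one loss per position, so padded positions can be
# masked out before averaging over the real tokens only.
loss_none = nn.CrossEntropyLoss(reduction='none')
l = loss_none(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1))
print(l.shape)  # torch.Size([6]) -- n*tau per-token losses
mlm_l = (l * mlm_weights_X.reshape(-1)).sum() / (mlm_weights_X.sum() + 1e-8)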
I agree. The CrossEntropyLoss object should have been instantiated with the parameter reduction specified as 'none'. Otherwise it will average the loss over both valid and invalid (padded) tokens. The correct instantiation statement is

loss = nn.CrossEntropyLoss(reduction='none')
Also, the code that filters out the losses for padding tokens
mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) *\
mlm_weights_X.reshape(-1, 1)
should be changed to
mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) *\
mlm_weights_X.reshape(-1)
After this revised line of code has been executed, mlm_l would have shape (5120,) (here, batch size = 512 and the maximum number of masked tokens per sequence = 10, so 512 * 10 = 5120).
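Putting both fixes together, the revised helper might look like this; the argument list and the NSP branch are assumed to follow the section's listing (which the thread does not quote in full), so treat it as a sketch rather than the book's exact code:

loss = nn.CrossEntropyLoss(reduction='none')

def _get_batch_loss_bert(net, loss, vocab_size, tokens_X, segments_X,
                         valid_lens_x, pred_positions_X, mlm_weights_X,
                         mlm_Y, nsp_y):
    # Forward pass
    _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,
                                  valid_lens_x.reshape(-1),
                                  pred_positions_X)
    # Per-position MLM losses, zeroed at padded prediction positions
    mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) *\
        mlm_weights_X.reshape(-1)
    # Average over the real masked tokens only
    mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)
    # With reduction='none' the NSP loss is per example, so take the mean
    nsp_l = loss(nsp_Y_hat, nsp_y).mean()
    l = mlm_l + nsp_l
    return mlm_l, nsp_l, l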
Hello, I used this method to pretrain a BERT model on a humongous DNA corpus. The 'problem' is that it finishes pretraining on the entire dataset in just 15 minutes, whereas a BERT trained from scratch would take at least a few days (along with the corresponding processing capabilities). Could someone let me know why this is happening, and whether the model is actually learning?
I agree with you: the original code will output the same loss as loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) on its own.
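A quick numeric check of that claim, with an arbitrary scalar standing in for the mean-reduced loss:

import torch

s = torch.tensor(2.5)               # scalar returned by reduction='mean'
w = torch.tensor([1., 1., 0., 0.])  # flattened mlm_weights_X
# (s * w).sum() / w.sum() equals s * w.sum() / w.sum(): the weights cancel
print((s * w).sum() / (w.sum() + 1e-8))  # tensor(2.5000)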