I have a question about the following function:

```python
def squared_loss(y_hat, y):  #@save
    """Squared loss."""
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2
```
Since `y_hat` is generated by `linreg()`, its shape will be `(batch_size, 1)`, which is the same as the shape of `y`. So why is `y.reshape(y_hat.shape)` necessary? I get wrong results after deleting it.
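A minimal sketch of one possible explanation, assuming that in the setup that produces the wrong results, `y` actually arrives as a 1-D array of shape `(batch_size,)` rather than `(batch_size, 1)`. NumPy is used here as a stand-in for the framework tensors. `squared_loss_no_reshape` is a hypothetical variant written only for comparison. When `y` really is `(batch_size, 1)`, the reshape is a no-op. When it is 1-D, subtracting it from a `(batch_size, 1)` array broadcasts to `(batch_size, batch_size)`, so the loss compares every prediction with every label:

```python
import numpy as np

def squared_loss(y_hat, y):
    """Squared loss, with y reshaped to match y_hat's shape."""
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

def squared_loss_no_reshape(y_hat, y):
    """Hypothetical variant without the reshape, for comparison."""
    return (y_hat - y) ** 2 / 2

batch_size = 4
# Predictions shaped like a linreg output: (batch_size, 1)
y_hat = np.arange(batch_size, dtype=float).reshape(batch_size, 1)
# Assumption: labels arrive as a 1-D vector of shape (batch_size,)
y_1d = np.arange(batch_size, dtype=float) + 1.0
# The same labels as a column vector of shape (batch_size, 1)
y_2d = y_1d.reshape(-1, 1)

print(squared_loss(y_hat, y_1d).shape)             # (4, 1)
print(squared_loss_no_reshape(y_hat, y_1d).shape)  # (4, 4): broadcasting
# With (batch_size, 1) labels, both versions agree
print(np.allclose(squared_loss(y_hat, y_2d),
                  squared_loss_no_reshape(y_hat, y_2d)))  # True
```

Here the summed loss is 2.0 with the reshape but 28.0 without it, because the broadcast version adds up all 16 pairwise differences. Printing `y.shape` right before the loss call would show whether this is what happens in a given run.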