I have a question about the following function:
```python
def squared_loss(y_hat, y):  #@save
    """Squared loss."""
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2
```
Since `y_hat` is produced by `linreg()`, its shape is `(batch_size, 1)`, which should be the same as the shape of `y`. So why is `y.reshape(y_hat.shape)` necessary? When I delete it, I get wrong results.
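One way to see how removing the reshape could change the result is a minimal sketch of broadcasting. It assumes that `y` actually arrives as a 1-D array of shape `(batch_size,)` rather than `(batch_size, 1)`. NumPy stands in here for the framework tensor; PyTorch and MXNet follow the same broadcasting rules. `squared_loss_no_reshape` is a hypothetical variant with the reshape removed, and the sample values are made up for illustration.

```python
import numpy as np

def squared_loss(y_hat, y):
    """Squared loss, with y reshaped to match y_hat."""
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

def squared_loss_no_reshape(y_hat, y):
    """Hypothetical variant: the reshape is removed."""
    return (y_hat - y) ** 2 / 2

batch_size = 4
# Predictions shaped like linreg() output: (batch_size, 1).
y_hat = np.arange(batch_size, dtype=float).reshape(batch_size, 1)
# Assumption: labels are 1-D, shape (batch_size,).
y_1d = np.arange(batch_size, dtype=float) + 1.0

# With the reshape, there is one loss per example.
print(squared_loss(y_hat, y_1d).shape)             # (4, 1)
# Without it, (4, 1) - (4,) broadcasts to (4, 4), which pairs
# every prediction with every label.
print(squared_loss_no_reshape(y_hat, y_1d).shape)  # (4, 4)

# The summed loss, and therefore the gradient, changes.
print(squared_loss(y_hat, y_1d).sum())             # 2.0
print(squared_loss_no_reshape(y_hat, y_1d).sum())  # 28.0

# If y really is (batch_size, 1), the reshape is a no-op and both versions agree.
y_2d = y_1d.reshape(-1, 1)
print(np.allclose(squared_loss_no_reshape(y_hat, y_2d), squared_loss(y_hat, y_1d)))  # True
```

If the two versions disagree in practice, that suggests the labels reaching the loss are 1-D at some point, even when `y_hat` is `(batch_size, 1)`. The reshape guards against that silent broadcast.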