A question about Section 3.2, linear regression implementation

I have a question about the following function:

```python
def squared_loss(y_hat, y):  #@save
    """Squared loss."""
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2
```

Since y_hat is produced by linreg(), its shape is (batch_size, 1), which I assumed is also the shape of y. So why is y.reshape(y_hat.shape) necessary? When I delete it, I get wrong results.
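One plausible cause, offered as an assumption rather than a confirmed diagnosis: if y actually arrives 1-D, with shape (batch_size,), then `y_hat - y` does not raise an error but broadcasts to a (batch_size, batch_size) matrix, so every prediction is compared against every label. The minimal sketch below uses NumPy arrays as a stand-in for the framework tensors; `squared_loss_no_reshape` is a hypothetical variant written only for comparison.

```python
import numpy as np

def squared_loss(y_hat, y):
    """Squared loss, with y reshaped to match y_hat."""
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

def squared_loss_no_reshape(y_hat, y):
    """Hypothetical variant with the reshape removed."""
    return (y_hat - y) ** 2 / 2

batch_size = 4
# Predictions shaped like linreg() output: (batch_size, 1)
y_hat = np.arange(batch_size, dtype=float).reshape(batch_size, 1)
# Assumption: labels are 1-D, shape (batch_size,)
y = np.ones(batch_size)

print(squared_loss(y_hat, y).shape)             # (4, 1): one loss per example
print(squared_loss_no_reshape(y_hat, y).shape)  # (4, 4): (4, 1) - (4,) broadcasts
print(float(squared_loss(y_hat, y).sum()))             # 3.0
print(float(squared_loss_no_reshape(y_hat, y).sum()))  # 12.0, inflated by broadcasting
```

If y really has shape (batch_size, 1), the two functions give identical results, so checking `y.shape` just before the loss call would show whether this is what is happening.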

Could the author of this chapter please explain?
@oliver