I have a question about the following function:
    def squared_loss(y_hat, y):  #@save
        """Squared loss."""
        return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2
`y_hat` is generated by `linreg()`, so its shape will be (`batch_size`, 1), the same as `y`. Why, then, is `y.reshape(y_hat.shape)` necessary? I get wrong results after deleting it.