The classifier used in the text seems rather unnatural. A more natural approach is to flatten the images, normalize them, and then use the dot product discussed in the text as a measure of similarity. The predicted label is the label of the average image that is more similar to the test image, hence the argmax.
```python
# stack the flattened average images as columns and normalize each column
# (broadcasting)
W = torch.stack([ave_0.flatten(), ave_1.flatten()], dim=1)
W = W / torch.norm(W, dim=0).reshape(1, -1)

# flatten and normalize the test images (one row per image)
X_test = X_test.reshape(-1, 784)
X_test = X_test / torch.norm(X_test, dim=1).reshape(-1, 1)

# predict and evaluate: argmax over the two similarity scores
y_pred = torch.argmax(X_test @ W, dim=1)
print((y_test == y_pred).type(torch.float).mean())
```
This achieves an accuracy of about 0.95.
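For completeness, here is a self-contained sketch of the whole pipeline on synthetic data, since the snippet above assumes `ave_0`, `ave_1`, `X_test`, and `y_test` already exist. The template/training names and the synthetic data are my own assumptions for illustration, not from the text:

```python
import torch

torch.manual_seed(0)

# Synthetic stand-ins for two digit classes: noisy copies of two random
# 28x28 templates (these names are hypothetical, not from the text).
template_0 = torch.rand(28, 28)
template_1 = torch.rand(28, 28)
X_train = torch.cat([template_0 + 0.1 * torch.rand(100, 28, 28),
                     template_1 + 0.1 * torch.rand(100, 28, 28)])
y_train = torch.cat([torch.zeros(100, dtype=torch.long),
                     torch.ones(100, dtype=torch.long)])

# Class-average images, one per label
ave_0 = X_train[y_train == 0].mean(dim=0)
ave_1 = X_train[y_train == 1].mean(dim=0)

# Same classifier as above: normalized dot products, argmax over columns
W = torch.stack([ave_0.flatten(), ave_1.flatten()], dim=1)
W = W / torch.norm(W, dim=0)
X = X_train.reshape(-1, 784)
X = X / torch.norm(X, dim=1, keepdim=True)
y_pred = torch.argmax(X @ W, dim=1)
accuracy = (y_train == y_pred).float().mean().item()
print(accuracy)
```

On this easy synthetic data the accuracy is near 1; on real MNIST test images the same procedure gives the ~0.95 reported above.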
What’s the interpretation of A^4 in exercise 7?