Geometry and Linear Algebraic Operations

The classifier used in the text seems rather unnatural. A more natural approach is to flatten the images, normalize them to unit length, and then use the dot product discussed in the text (the cosine of the angle between two vectors) as a measure of similarity. The predicted label is the label of whichever average image is more similar to the test image, hence the argmax.

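```python
import torch

# ave_0, ave_1, X_test and y_test are the tensors defined in the section.
# Stack the flattened class averages as the columns of a 784 x 2 matrix,
# then scale each column to unit length (broadcasting over the rows).
W = torch.stack([ave_0.flatten(), ave_1.flatten()], dim=1)
W = W / torch.norm(W, dim=0).reshape(1, -1)

# Flatten each test image to a 784-vector and scale each row to unit length.
X_test = X_test.reshape(-1, 784)
X_test = X_test / torch.norm(X_test, dim=1).reshape(-1, 1)

# Each entry of X_test @ W is a cosine similarity; pick the more similar average.
y_pred = torch.argmax(X_test @ W, dim=1)
print((y_test == y_pred).type(torch.float).mean())
```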

This achieves a test accuracy of about 0.95. :grinning:
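To see the arithmetic without Fashion-MNIST, here is a framework-free sketch in plain Python on made-up 4-pixel "images". The vectors and the helper names `normalize`, `predict_cosine` and `predict_linear` are hypothetical, chosen only for illustration. It also checks a side observation: with two classes, the argmax picks class 1 exactly when x̂·(ŵ1 − ŵ0) > 0, and because dividing by ‖x‖ > 0 does not change the sign, normalizing the test images does not change the predictions. Only normalizing the averages matters, so the rule can be read as a linear classifier with weight vector ŵ1 − ŵ0 and threshold 0.

```python
import math

def normalize(v):
    """Scale a vector to unit Euclidean length (assumes v is not all zeros)."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def predict_cosine(x, centroids):
    """Index of the centroid with the highest cosine similarity to x.

    Python's max returns the first maximal index, so ties go to the lower index.
    """
    x_hat = normalize(x)
    scores = [dot(x_hat, normalize(c)) for c in centroids]
    return max(range(len(scores)), key=lambda i: scores[i])

def predict_linear(x, ave_0, ave_1):
    """Two-class rule: sign of x . (w1_hat - w0_hat); x itself is NOT normalized."""
    w = [b - a for a, b in zip(normalize(ave_0), normalize(ave_1))]
    return 1 if dot(x, w) > 0 else 0

# Hypothetical class averages: class 0 is bright on the left, class 1 on the right.
ave_0 = [0.9, 0.8, 0.1, 0.0]
ave_1 = [0.1, 0.0, 0.7, 0.9]
tests = [
    [1.0, 0.7, 0.2, 0.1],
    [0.0, 0.2, 0.8, 1.0],
    [10.0, 7.0, 2.0, 1.0],  # 10x brighter copy of the first test image
    [0.3, 0.2, 0.4, 0.6],
]

cosine_preds = [predict_cosine(x, [ave_0, ave_1]) for x in tests]
linear_preds = [predict_linear(x, ave_0, ave_1) for x in tests]
print(cosine_preds, linear_preds)  # → [0, 1, 0, 1] [0, 1, 0, 1]
```

Both rules break exact ties toward class 0, and the sketch assumes no all-zero vectors, since those cannot be normalized. The brighter copy of the first image gets the same label, which is the scale invariance that motivates normalizing in the first place.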


What’s the interpretation of A^4 in exercise 7?