The last two lines of the function

```python
_, topk = torch.topk(cos, k=k)
return topk, [cos[int(i)] for i in topk]
```

can be simplified to

```python
vals, topk = torch.topk(cos, k=k)
return topk, vals
```
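`torch.topk` already returns both the top-k values and their indices as a `(values, indices)` pair. The original code throws the values away and then recovers them by indexing `cos` in a list comprehension, which seems redundant to me. The only behavioral difference is the return type: the simplified version returns the similarities as a 1-D tensor instead of a Python list of 0-d tensors. If a caller really needs a list, `list(vals)` gives the same kind of list as before.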
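To check that the two versions agree, here is a small sketch. The function names `top_k_similar_original` and `top_k_similar` are hypothetical, since the post does not name the function. The `cos` tensor is a made-up toy vector of cosine similarities rather than output from a real embedding.

```python
import torch

def top_k_similar_original(cos: torch.Tensor, k: int):
    # Original version: discard the values, then re-index cos for each index
    _, topk = torch.topk(cos, k=k)
    return topk, [cos[int(i)] for i in topk]

def top_k_similar(cos: torch.Tensor, k: int):
    # Simplified version: keep the values torch.topk already computed
    vals, topk = torch.topk(cos, k=k)
    return topk, vals

# Toy similarity scores (hypothetical data, all values distinct so there are no ties)
cos = torch.tensor([0.1, 0.9, -0.3, 0.5, 0.7])

idx_old, vals_old = top_k_similar_original(cos, k=3)
idx_new, vals_new = top_k_similar(cos, k=3)

print(idx_new.tolist())                                   # [1, 4, 3]
print(torch.equal(idx_old, idx_new))                      # True
print(torch.allclose(torch.stack(vals_old), vals_new))    # True
```

`torch.topk` sorts its results in descending order by default (`largest=True`, `sorted=True`), so both versions return the indices and similarities in the same order. The `torch.stack` call is only needed to compare the old list-of-scalars result against the new tensor.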