许多代码在m1芯片上无法直接使用的问题

因为我只有使用M1芯片的macbook,不确定在m2m3上是否有问题。不过在我的电脑上有问题是肯定的。
问题的根源是pytorch中torch.argmax()这个函数,它的本意是取出最大值的索引,然而,如果你使用的是"device='mps"而不是’cuda’的话,很多时候它会取出一个明显的错误值“-9223372036854775808”,而如果你使用’cpu’则一切正常,这说明问题就在于使用的GPU。
这个问题可能会出现在很多章节,事实上,所有使用了torch.argmax()函数的部分都可能会出现这个问题。包括不限于:
1. 在mps上训练时无法得出正确的accurcy:


2. rnn从零实现中,函数predict_ch8():报错 “IndexError: list index out of range”

有一个可用的解决方法是使用torch.max()函数作为替代,它的作用是返回 最大值及其索引,而前面提到了,argmax()函数的作用正是返回索引,所以我们可以这样做:
#outputs.append(int(y.argmax(dim=1).reshape(1)))
_, y = y.max(dim=1)
outputs.append(int(y.reshape(1)))
这样就可以解决问题了,所有的argmax()引起的问题都可以用类似的方法解决。

我觉得如果有新版的话应该在此处加以说明?毕竟这个问题还是挺麻烦的