同一个数据集构造一次就行了。。。。。。。
这里有一个地方可以优化,损失函数的构建原来是l.sum().backward()这样的损失函数是该批次损失的和计算损失,这样受批次大小的影响比较大。测试的时候batch是32。在进行实际的时候,batch调整到128。稳定性就不好了,sum`的梯度更新幅度较大,可能会导致训练不稳定。学习率增加到0.1,两个因素叠加,所以非常不稳定。
找一下训练函数,l.sum().backward(),改成,l.mean().backward()就稳定了