The way the book defines a custom sequential block seems not to work with newer versions of torch; it raises the following error:
TypeError Traceback (most recent call last)
Cell In[23], line 2
1 net = MySequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))
----> 2 net(X)
File ~\.conda\envs\d2l\lib\site-packages\torch\nn\modules\module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File ~\.conda\envs\d2l\lib\site-packages\torch\nn\modules\module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
Cell In[22], line 8, in MySequential.forward(self, X)
7 def forward(self, X):
----> 8 for module in self.modules:
9 X = module(X)
10 return X
TypeError: 'method' object is not iterable
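The message can be reproduced in isolation. The following is a minimal sketch that assumes only torch; the Demo class is hypothetical and exists only for illustration. modules is a method that every nn.Module inherits, so looping over it without calling it raises the same TypeError. The registered submodules are reachable through children():

```python
from torch import nn


class Demo(nn.Module):
    """Hypothetical module used only to reproduce the error."""

    def __init__(self):
        super().__init__()
        # Assigning a module as an attribute registers it as a submodule
        self.layer = nn.Linear(4, 2)


demo = Demo()

# demo.modules is a bound method inherited from nn.Module, not a container
print(type(demo.modules))  # <class 'method'>

try:
    for m in demo.modules:  # missing parentheses: iterates over the method object
        pass
except TypeError as err:
    print(err)  # 'method' object is not iterable

# children() yields the direct submodules in registration order
print([type(m).__name__ for m in demo.children()])  # ['Linear']
```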
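The last frame of the traceback shows the cause. Inside forward, self.modules refers to the modules() method that every nn.Module inherits, so the for loop tries to iterate over a method object. Defining the sequential block as follows avoids the problem: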
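from torch import nn

class MySequential(nn.Module):
    def __init__(self, *args):
        # Equivalent to super().__init__() in Python 3; either form works
        super(MySequential, self).__init__()
        # nn.ModuleList registers each submodule so that its parameters are tracked
        self.module_list = nn.ModuleList(args)

    def forward(self, X):
        # Iterate over the registered submodules, not over self.modules (a method)
        for module in self.module_list:
            X = module(X)
        return X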
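The class above can be checked with a short, self-contained usage sketch. The class is repeated here so that the snippet runs on its own. The input X is an assumption: a random batch of 2 samples with 20 features. Because nn.ModuleList registers its contents, the parameters of both linear layers appear in net.parameters():

```python
import torch
from torch import nn


class MySequential(nn.Module):
    def __init__(self, *args):
        super().__init__()
        # nn.ModuleList registers every submodule, so their parameters are tracked
        self.module_list = nn.ModuleList(args)

    def forward(self, X):
        # Apply the submodules in the order they were passed in
        for module in self.module_list:
            X = module(X)
        return X


X = torch.rand(2, 20)  # assumed input: 2 samples, 20 features each
net = MySequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))
out = net(X)
print(out.shape)  # torch.Size([2, 10])

# Weight and bias of each Linear layer are registered: 4 tensors in total
print(len(list(net.parameters())))  # 4
```

A plain Python list is not a drop-in substitute for nn.ModuleList here. The layers would still run in forward, but their parameters would not be registered, so net.parameters() would not include them.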