Layers and Blocks

When I implemented the book's custom sequential block with a recent version of torch, calling it raised the following error:

TypeError                                 Traceback (most recent call last)
Cell In[23], line 2
      1 net = MySequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))
----> 2 net(X)

File ~\.conda\envs\d2l\lib\site-packages\torch\nn\modules\module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~\.conda\envs\d2l\lib\site-packages\torch\nn\modules\module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

Cell In[22], line 8, in MySequential.forward(self, X)
      7 def forward(self, X):
----> 8     for module in self.modules:
      9         X = module(X)
     10     return X

TypeError: 'method' object is not iterable
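The error can be reproduced in isolation. The class name `Demo` below is a hypothetical example; the only PyTorch API it relies on is `nn.Module.modules()`, which is a method and not a container. Calling `self.modules()` in `forward` would not fix the bug either, because it yields the module itself first and would therefore re-enter `forward` recursively.

```python
import torch
from torch import nn

class Demo(nn.Module):  # hypothetical minimal module
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

net = Demo()
# `modules` is a bound method inherited from nn.Module
print(type(net.modules))  # <class 'method'>

try:
    for m in net.modules:  # same mistake as in the traceback: no call, no container
        pass
except TypeError as e:
    err_msg = str(e)
print(err_msg)  # 'method' object is not iterable

# Calling it yields the module itself first, then its submodules recursively
print([type(m).__name__ for m in net.modules()])  # ['Demo', 'Linear']
```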

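The traceback shows the cause: `forward` iterates over `self.modules`, but in PyTorch `modules` is a method inherited from `nn.Module` (`Module.modules()` returns an iterator over all modules in the network), not a container of submodules, hence `TypeError: 'method' object is not iterable`. Storing the submodules explicitly in an `nn.ModuleList` and iterating over that list works: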

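import torch
from torch import nn

class MySequential(nn.Module):
    def __init__(self, *args):
        # equivalent to super().__init__() in Python 3; either form works
        super(MySequential, self).__init__()
        # nn.ModuleList registers every submodule, so its parameters are tracked
        self.module_list = nn.ModuleList(args)

    def forward(self, X):
        # apply the submodules in the order they were passed in;
        # iterate over this list, not self.modules (a method)
        for module in self.module_list:
            X = module(X)
        return X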
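A quick check that the `ModuleList` version runs, alongside a variant in the style of the book that stores children in the `_modules` OrderedDict and iterates over `self._modules.values()`. The name `DictSequential` is hypothetical, used only to tell the two classes apart; both should produce a `(2, 10)` output and register 4 parameter tensors (weight and bias of each `nn.Linear`).

```python
import torch
from torch import nn

class MySequential(nn.Module):
    """ModuleList-based version from the post."""
    def __init__(self, *args):
        super().__init__()
        self.module_list = nn.ModuleList(args)

    def forward(self, X):
        for module in self.module_list:
            X = module(X)
        return X

class DictSequential(nn.Module):
    """Hypothetical name: _modules-based variant in the book's style."""
    def __init__(self, *args):
        super().__init__()
        for idx, module in enumerate(args):
            # register each child under the key '0', '1', ... in _modules
            self._modules[str(idx)] = module

    def forward(self, X):
        # iterate over the dict's values (the child modules)
        for block in self._modules.values():
            X = block(X)
        return X

X = torch.rand(2, 20)
for cls in (MySequential, DictSequential):
    net = cls(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))
    print(cls.__name__, net(X).shape, len(list(net.parameters())))
```

Both versions register their children as submodules, so optimizers and `state_dict()` see the parameters; the only requirement is to iterate over the stored children rather than over `self.modules`.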