My solutions to the exercises: 6.1
My Solution to Q3:
class DaisyX(nn.Module):
    """Chain several fresh instances of one module class end to end.

    Each instance is registered as a child module named
    ``"<idx><ClassName>"`` (for example ``"0Increment"``), and ``forward``
    pipes its input through the children in registration order.

    Args:
        genericModule: The module *class* (not an instance) to instantiate.
            It is called with no arguments once per link of the chain.
        chain_length: Number of instances to chain. Zero or a negative value
            produces an empty chain, in which case ``forward`` returns its
            input unchanged.
    """

    def __init__(self, genericModule: "type[nn.Module]", chain_length: int = 5):
        super().__init__()
        for idx in range(chain_length):
            # Instantiate anew on every iteration so each link is its own
            # object; the index prefix keeps the registered names unique.
            self.add_module(f"{idx}{genericModule.__name__}", genericModule())

    def forward(self, X):
        """Pass ``X`` through every child module in turn and return the result."""
        for link in self.children():
            X = link(X)
        return X
class Increment(nn.Module):
    """Parameter-free module that adds one to its input."""

    # No __init__ override: the original only forwarded to
    # nn.Module.__init__, which runs anyway when none is defined here.

    def forward(self, X):
        """Return ``X + 1``."""
        return X + 1
# Demo: chain five Increment modules and push a 2x2 zero tensor through them.
net = DaisyX(Increment, chain_length=5)
X = torch.zeros(2, 2)
net(X)  # result is discarded here; it is only displayed in an interactive session