n = sum(p.numel() for p in model.parameters())
whole = next(model.parameters()).new(n) # Get same device and dtype
whole.requires_grad_()
- _flatparam(model, whole, [], 0)
- return whole
+ _flatparam(model, whole)
+ model.parameters = lambda: iter([ whole ])
######################################################################
for p in model.parameters():
print(p.size(), p.storage().size())
-whole = flatparam(model)
+flatparam(model)
print('After:')
for p in model.parameters():
input = torch.rand(100, 2)
targets = torch.rand(100, 2)
-optimizer = torch.optim.SGD([ whole ], lr = 1e-2)
+optimizer = torch.optim.SGD(model.parameters(), lr = 1e-2)
mse = nn.MSELoss()
for e in range(10):