Niiice.
authorFrancois Fleuret <francois@fleuret.org>
Thu, 10 Sep 2020 16:57:09 +0000 (18:57 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Thu, 10 Sep 2020 16:57:09 +0000 (18:57 +0200)
flatparam.py

index fbede34..e0627b2 100755 (executable)
@@ -24,8 +24,8 @@ def flatparam(model):
     n = sum(p.numel() for p in model.parameters())
     whole = next(model.parameters()).new(n) # Get same device and dtype
     whole.requires_grad_()
-    _flatparam(model, whole, [], 0)
-    return whole
+    _flatparam(model, whole)
+    model.parameters = lambda: iter([ whole ])
 
 ######################################################################
 
@@ -44,7 +44,7 @@ print('Before:')
 for p in model.parameters():
     print(p.size(), p.storage().size())
 
-whole = flatparam(model)
+flatparam(model)
 
 print('After:')
 for p in model.parameters():
@@ -56,7 +56,7 @@ print('Check:')
 
 input = torch.rand(100, 2)
 targets = torch.rand(100, 2)
-optimizer = torch.optim.SGD([ whole ], lr = 1e-2)
+optimizer = torch.optim.SGD(model.parameters(), lr = 1e-2)
 mse = nn.MSELoss()
 
 for e in range(10):