Following @apaszke's suggestion
authorFrancois Fleuret <francois@fleuret.org>
Thu, 10 Sep 2020 11:23:16 +0000 (13:23 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Thu, 10 Sep 2020 11:23:16 +0000 (13:23 +0200)
flatparam.py

index 3c20153..fbede34 100755 (executable)
@@ -5,27 +5,46 @@ from torch import nn
 
 ######################################################################
 
+def _flatparam(model, whole, already = [], offset = 0):
+    for v in model._parameters:
+        p = model._parameters[v]
+        e = p.numel()
+        s = p.size()
+        model._parameters[v] = whole[offset:offset+e].view(s)
+        with torch.no_grad():
+            model._parameters[v].copy_(p)
+        offset += e
+    already.append(model)
+    for m in model.modules():
+        if m not in already:
+            offset = _flatparam(m, whole, already, offset)
+    return offset
+
 def flatparam(model):
-    with torch.no_grad():
-        n = sum(p.numel() for p in model.parameters())
-        big = next(model.parameters()).new(n) # Get same device and dtype
-        k = 0
-        for p in model.parameters():
-            tmp = p.new(0).set_(p)
-            p.set_(big.storage(), k, p.size()).copy_(tmp)
-            k += p.numel()
+    n = sum(p.numel() for p in model.parameters())
+    whole = next(model.parameters()).new(n) # Get same device and dtype
+    whole.requires_grad_()
+    _flatparam(model, whole, [], 0)
+    return whole
 
 ######################################################################
 
 model = nn.Sequential(
-    nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 2)
+    nn.Linear(2, 4),
+    nn.ReLU(),
+    nn.Sequential(
+        nn.Linear(4, 4),
+        nn.ReLU(), nn.Linear(4, 2)
+    )
 )
 
+######################################################################
+
 print('Before:')
 for p in model.parameters():
     print(p.size(), p.storage().size())
 
-flatparam(model)
+whole = flatparam(model)
 
 print('After:')
 for p in model.parameters():
@@ -37,7 +56,7 @@ print('Check:')
 
 input = torch.rand(100, 2)
 targets = torch.rand(100, 2)
-optimizer = torch.optim.SGD(model.parameters(), lr = 1e-2)
+optimizer = torch.optim.SGD([ whole ], lr = 1e-2)
 mse = nn.MSELoss()
 
 for e in range(10):