Update.
[pytorch.git] / flatparam.py
index e0627b2..0b61cf1 100755 (executable)
@@ -5,12 +5,13 @@ from torch import nn
 
 ######################################################################
 
-def _flatparam(model, whole, already = [], offset = 0):
+
+def _flatparam(model, whole, already=[], offset=0):
     for v in model._parameters:
         p = model._parameters[v]
         e = p.numel()
         s = p.size()
-        model._parameters[v] = whole[offset:offset+e].view(s)
+        model._parameters[v] = whole[offset : offset + e].view(s)
         with torch.no_grad():
             model._parameters[v].copy_(p)
         offset += e
@@ -20,43 +21,42 @@ def _flatparam(model, whole, already = [], offset = 0):
             offset = _flatparam(m, whole, already, offset)
     return offset
 
+
 def flatparam(model):
     n = sum(p.numel() for p in model.parameters())
-    whole = next(model.parameters()).new(n) # Get same device and dtype
+    whole = next(model.parameters()).new(n)  # Get same device and dtype
     whole.requires_grad_()
     _flatparam(model, whole)
-    model.parameters = lambda: iter([ whole ])
+    model.parameters = lambda: iter([whole])
+
 
 ######################################################################
 
 model = nn.Sequential(
     nn.Linear(2, 4),
     nn.ReLU(),
-    nn.Sequential(
-        nn.Linear(4, 4),
-        nn.ReLU(), nn.Linear(4, 2)
-    )
+    nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2)),
 )
 
 ######################################################################
 
-print('Before:')
+print("Before:")
 for p in model.parameters():
     print(p.size(), p.storage().size())
 
 flatparam(model)
 
-print('After:')
+print("After:")
 for p in model.parameters():
     print(p.size(), p.storage().size())
 
 ######################################################################
 
-print('Check:')
+print("Check:")
 
 input = torch.rand(100, 2)
 targets = torch.rand(100, 2)
-optimizer = torch.optim.SGD(model.parameters(), lr = 1e-2)
+optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
 mse = nn.MSELoss()
 
 for e in range(10):