Initial commit.
author Francois Fleuret <francois@fleuret.org>
Thu, 10 Sep 2020 06:36:59 +0000 (08:36 +0200)
committer Francois Fleuret <francois@fleuret.org>
Thu, 10 Sep 2020 06:36:59 +0000 (08:36 +0200)
flatparam.py [new file with mode: 0755]

diff --git a/flatparam.py b/flatparam.py
new file mode 100755 (executable)
index 0000000..3c20153
--- /dev/null
@@ -0,0 +1,49 @@
+#!/usr/bin/env python
+
+import torch, torchvision
+from torch import nn
+
+######################################################################
+
def flatparam(model):
    """Re-home all of *model*'s parameters into one contiguous flat buffer.

    After the call each parameter is a view into a single tensor whose
    storage holds every parameter value back to back.  Parameter sizes
    and values are unchanged; only the underlying storages are replaced.
    Runs under ``torch.no_grad()`` since it mutates parameters in place.
    """
    with torch.no_grad():
        # Total number of scalar parameters across the whole model.
        n = sum(p.numel() for p in model.parameters())
        big = next(model.parameters()).new(n) # Get same device and dtype
        k = 0  # running element offset into the flat buffer
        for p in model.parameters():
            # tmp aliases p's *current* storage, keeping the old values
            # reachable while p itself is re-pointed just below.
            tmp = p.new(0).set_(p)
            # Make p a view into big's storage at offset k, then copy the
            # saved values back in.  NOTE(review): .storage() is the legacy
            # typed-storage API; recent PyTorch deprecates it in favor of
            # untyped_storage() — confirm against the targeted version.
            p.set_(big.storage(), k, p.size()).copy_(tmp)
            k += p.numel()
+
+######################################################################
+
def _show_storages(m):
    # One line per parameter: its shape and its backing storage's element count.
    for param in m.parameters():
        print(param.size(), param.storage().size())

# A small MLP to demonstrate parameter flattening on.
model = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 2))

# Before flattening, each parameter owns its own storage; afterwards
# they all share one big storage.
print('Before:')
_show_storages(model)

flatparam(model)

print('After:')
_show_storages(model)
+
+######################################################################
+
# Sanity check: training still works after flattening, i.e. autograd and
# the optimizer correctly update parameters that live in the shared storage.
print('Check:')

# `inputs` (not `input`) avoids shadowing the Python builtin of that name.
inputs = torch.rand(100, 2)
targets = torch.rand(100, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
mse = nn.MSELoss()

# The printed loss should (noisily) decrease over the epochs.
for e in range(10):
    output = model(inputs)
    loss = mse(output, targets)
    print(e, loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()