--- /dev/null
+#!/usr/bin/env python
+
+import torch, torchvision
+from torch import nn
+
+######################################################################
+
def flatparam(model):
    """Relocate all of ``model``'s parameters into one contiguous buffer.

    After the call every parameter tensor is a view into a single flat
    tensor (same device and dtype as the parameters), so the whole
    parameter set can be manipulated as one vector. Parameter values,
    shapes, and ``requires_grad`` flags are preserved; optimizers built
    on ``model.parameters()`` afterwards keep working.

    Args:
        model: any ``nn.Module``. A module with no parameters is a no-op.
    """
    with torch.no_grad():
        params = list(model.parameters())
        if not params:
            return  # nothing to flatten

        # One flat buffer with the same device/dtype as the parameters.
        flat = params[0].new_empty(sum(p.numel() for p in params))

        k = 0
        for p in params:
            # Carve out this parameter's slice of the flat buffer,
            # copy the current values in, then re-point the parameter
            # at that slice (set_(tensor) adopts its storage/offset/shape).
            chunk = flat[k:k + p.numel()].view_as(p)
            chunk.copy_(p)
            p.set_(chunk)
            k += p.numel()
+
+######################################################################
+
# Small MLP used to demonstrate the flattening.
model = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 2))


def _dump_params(label):
    # Show each parameter's shape together with its underlying storage
    # size: before flattening each storage matches its own tensor, after
    # flattening every parameter reports the one shared flat storage.
    print(label)
    for param in model.parameters():
        print(param.size(), param.storage().size())


_dump_params('Before:')
flatparam(model)
_dump_params('After:')
+
+######################################################################
+
print('Check:')

# Random regression problem; the data is meaningless, the point is only
# to verify that training still works on the relocated parameters.
inputs = torch.rand(100, 2)   # renamed from `input`: do not shadow the builtin
targets = torch.rand(100, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
mse = nn.MSELoss()

# A decreasing loss shows the flattened parameters are the ones being
# updated by the optimizer (loss is printed before each step).
for e in range(10):
    output = model(inputs)
    loss = mse(output, targets)
    print(e, loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()