From 7935bad172ac27fa77d28dc8bf7147f5b5aabaaa Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Thu, 10 Sep 2020 08:36:59 +0200 Subject: [PATCH] Initial commit. --- flatparam.py | 49 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100755 flatparam.py diff --git a/flatparam.py b/flatparam.py new file mode 100755 index 0000000..3c20153 --- /dev/null +++ b/flatparam.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python + +import torch, torchvision +from torch import nn + +###################################################################### + +def flatparam(model): + with torch.no_grad(): + n = sum(p.numel() for p in model.parameters()) + big = next(model.parameters()).new(n) # Get same device and dtype + k = 0 + for p in model.parameters(): + tmp = p.new(0).set_(p) + p.set_(big.storage(), k, p.size()).copy_(tmp) + k += p.numel() + +###################################################################### + +model = nn.Sequential( + nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 2) +) + +print('Before:') +for p in model.parameters(): + print(p.size(), p.storage().size()) + +flatparam(model) + +print('After:') +for p in model.parameters(): + print(p.size(), p.storage().size()) + +###################################################################### + +print('Check:') + +input = torch.rand(100, 2) +targets = torch.rand(100, 2) +optimizer = torch.optim.SGD(model.parameters(), lr = 1e-2) +mse = nn.MSELoss() + +for e in range(10): + output = model(input) + loss = mse(output, targets) + print(e, loss.item()) + optimizer.zero_grad() + loss.backward() + optimizer.step() -- 2.20.1