From a1c89c4da439a4ad48d8f79b6697a2108be4b514 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Thu, 10 Sep 2020 13:23:16 +0200 Subject: [PATCH] Following @apaszke's suggestion --- flatparam.py | 41 ++++++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/flatparam.py b/flatparam.py index 3c20153..fbede34 100755 --- a/flatparam.py +++ b/flatparam.py @@ -5,27 +5,46 @@ from torch import nn ###################################################################### +def _flatparam(model, whole, already = [], offset = 0): + for v in model._parameters: + p = model._parameters[v] + e = p.numel() + s = p.size() + model._parameters[v] = whole[offset:offset+e].view(s) + with torch.no_grad(): + model._parameters[v].copy_(p) + offset += e + already.append(model) + for m in model.modules(): + if m not in already: + offset = _flatparam(m, whole, already, offset) + return offset + def flatparam(model): - with torch.no_grad(): - n = sum(p.numel() for p in model.parameters()) - big = next(model.parameters()).new(n) # Get same device and dtype - k = 0 - for p in model.parameters(): - tmp = p.new(0).set_(p) - p.set_(big.storage(), k, p.size()).copy_(tmp) - k += p.numel() + n = sum(p.numel() for p in model.parameters()) + whole = next(model.parameters()).new(n) # Get same device and dtype + whole.requires_grad_() + _flatparam(model, whole, [], 0) + return whole ###################################################################### model = nn.Sequential( - nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 2) + nn.Linear(2, 4), + nn.ReLU(), + nn.Sequential( + nn.Linear(4, 4), + nn.ReLU(), nn.Linear(4, 2) + ) ) +###################################################################### + print('Before:') for p in model.parameters(): print(p.size(), p.storage().size()) -flatparam(model) +whole = flatparam(model) print('After:') for p in model.parameters(): @@ -37,7 +56,7 @@ print('Check:') input = torch.rand(100, 2) targets = torch.rand(100, 2) -optimizer = torch.optim.SGD(model.parameters(), lr = 1e-2) +optimizer = torch.optim.SGD([ whole ], lr = 1e-2) mse = nn.MSELoss() for e in range(10): -- 2.20.1