From 5db6798d929c15e1517ec10c1a9211f870ec977e Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Thu, 10 Sep 2020 18:57:09 +0200 Subject: [PATCH] Niiice. --- flatparam.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flatparam.py b/flatparam.py index fbede34..e0627b2 100755 --- a/flatparam.py +++ b/flatparam.py @@ -24,8 +24,8 @@ def flatparam(model): n = sum(p.numel() for p in model.parameters()) whole = next(model.parameters()).new(n) # Get same device and dtype whole.requires_grad_() - _flatparam(model, whole, [], 0) - return whole + _flatparam(model, whole) + model.parameters = lambda: iter([ whole ]) ###################################################################### @@ -44,7 +44,7 @@ print('Before:') for p in model.parameters(): print(p.size(), p.storage().size()) -whole = flatparam(model) +flatparam(model) print('After:') for p in model.parameters(): @@ -56,7 +56,7 @@ print('Check:') input = torch.rand(100, 2) targets = torch.rand(100, 2) -optimizer = torch.optim.SGD([ whole ], lr = 1e-2) +optimizer = torch.optim.SGD(model.parameters(), lr = 1e-2) mse = nn.MSELoss() for e in range(10): -- 2.39.5