projects
/
pytorch.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
7935bad
)
Following @apaszke's suggestion
author
Francois Fleuret
<francois@fleuret.org>
Thu, 10 Sep 2020 11:23:16 +0000
(13:23 +0200)
committer
Francois Fleuret
<francois@fleuret.org>
Thu, 10 Sep 2020 11:23:16 +0000
(13:23 +0200)
flatparam.py
patch
|
blob
|
history
diff --git a/flatparam.py b/flatparam.py
index 3c20153..fbede34 100755 (executable)
--- a/flatparam.py
+++ b/flatparam.py
@@ -5,27 +5,46 @@
from torch import nn
######################################################################
######################################################################
+def _flatparam(model, whole, already = [], offset = 0):
+ for v in model._parameters:
+ p = model._parameters[v]
+ e = p.numel()
+ s = p.size()
+ model._parameters[v] = whole[offset:offset+e].view(s)
+ with torch.no_grad():
+ model._parameters[v].copy_(p)
+ offset += e
+ already.append(model)
+ for m in model.modules():
+ if m not in already:
+ offset = _flatparam(m, whole, already, offset)
+ return offset
+
def flatparam(model):
- with torch.no_grad():
- n = sum(p.numel() for p in model.parameters())
- big = next(model.parameters()).new(n) # Get same device and dtype
- k = 0
- for p in model.parameters():
- tmp = p.new(0).set_(p)
- p.set_(big.storage(), k, p.size()).copy_(tmp)
- k += p.numel()
+ n = sum(p.numel() for p in model.parameters())
+ whole = next(model.parameters()).new(n) # Get same device and dtype
+ whole.requires_grad_()
+ _flatparam(model, whole, [], 0)
+ return whole
######################################################################

model = nn.Sequential(
- nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 2)
+ nn.Linear(2, 4),
+ nn.ReLU(),
+ nn.Sequential(
+ nn.Linear(4, 4),
+ nn.ReLU(), nn.Linear(4, 2)
+ )
)
+######################################################################
+
print('Before:')
for p in model.parameters():
    print(p.size(), p.storage().size())
-flatparam(model)
+whole = flatparam(model)
print('After:')
for p in model.parameters():
@@ -37,7 +56,7 @@
print('Check:')
input = torch.rand(100, 2)
targets = torch.rand(100, 2)
-optimizer = torch.optim.SGD(model.parameters(), lr = 1e-2)
+optimizer = torch.optim.SGD([ whole ], lr = 1e-2)
mse = nn.MSELoss()
for e in range(10):