projects
/
pytorch.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[pytorch.git]
/
flatparam.py
diff --git
a/flatparam.py
b/flatparam.py
index
fbede34
..
57a8720
100755
(executable)
--- a/
flatparam.py
+++ b/
flatparam.py
@@
-24,8
+24,8
@@
def flatparam(model):
n = sum(p.numel() for p in model.parameters())
whole = next(model.parameters()).new(n) # Get same device and dtype
whole.requires_grad_()
n = sum(p.numel() for p in model.parameters())
whole = next(model.parameters()).new(n) # Get same device and dtype
whole.requires_grad_()
- _flatparam(model, whole
, [], 0
)
- return whole
+ _flatparam(model, whole)
+ model.parameters = lambda: iter([ whole ])
######################################################################
######################################################################
@@
-34,7
+34,8
@@
model = nn.Sequential(
nn.ReLU(),
nn.Sequential(
nn.Linear(4, 4),
nn.ReLU(),
nn.Sequential(
nn.Linear(4, 4),
- nn.ReLU(), nn.Linear(4, 2)
+ nn.ReLU(),
+ nn.Linear(4, 2)
)
)
)
)
@@
-44,7
+45,7
@@
print('Before:')
for p in model.parameters():
print(p.size(), p.storage().size())
for p in model.parameters():
print(p.size(), p.storage().size())
-
whole =
flatparam(model)
+flatparam(model)
print('After:')
for p in model.parameters():
print('After:')
for p in model.parameters():
@@
-56,7
+57,7
@@
print('Check:')
input = torch.rand(100, 2)
targets = torch.rand(100, 2)
input = torch.rand(100, 2)
targets = torch.rand(100, 2)
-optimizer = torch.optim.SGD(
[ whole ]
, lr = 1e-2)
+optimizer = torch.optim.SGD(
model.parameters()
, lr = 1e-2)
mse = nn.MSELoss()
for e in range(10):
mse = nn.MSELoss()
for e in range(10):