######################################################################
-def _flatparam(model, whole, already = [], offset = 0):
+
+def _flatparam(model, whole, already=None, offset=0):
+    # None instead of a mutable default: with already=[], every call
+    # would share (and keep growing) the same visited-module list
+    if already is None:
+        already = []
for v in model._parameters:
p = model._parameters[v]
e = p.numel()
s = p.size()
- model._parameters[v] = whole[offset:offset+e].view(s)
+ model._parameters[v] = whole[offset : offset + e].view(s)
with torch.no_grad():
model._parameters[v].copy_(p)
offset += e
    # Recurse into children; `already` keeps a shared sub-module from being visited twice
    for m in model.children():
        if m not in already:
            already.append(m)
            offset = _flatparam(m, whole, already, offset)
return offset
+
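+# The trick relies on a basic PyTorch fact: a view of a slice shares the
+# storage of its source tensor. A minimal sketch (illustration only, with
+# throwaway names) of what each parameter re-assignment above does:
+_buf = torch.zeros(6)
+_w = _buf[0:4].view(2, 2)  # _w shares _buf's storage
+_w.fill_(1.0)              # writing through the view mutates _buf
+assert _buf.eq(torch.tensor([1.0, 1.0, 1.0, 1.0, 0.0, 0.0])).all()
+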
def flatparam(model):
n = sum(p.numel() for p in model.parameters())
- whole = next(model.parameters()).new(n) # Get same device and dtype
+ whole = next(model.parameters()).new(n) # Get same device and dtype
whole.requires_grad_()
- _flatparam(model, whole, [], 0)
- return whole
+ _flatparam(model, whole)
+    model.parameters = lambda: iter([whole])  # expose only the flat tensor to callers such as the optimizer
+
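+# For comparison: torch.nn.utils offers a *copying* counterpart,
+#
+#     vec = torch.nn.utils.parameters_to_vector(model.parameters())
+#
+# which concatenates values into a fresh tensor. flatparam differs in that
+# afterwards every parameter *is* a view into one buffer, so the optimizer
+# can update all of them through a single tensor with no per-step copies.
+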
######################################################################
model = nn.Sequential(
nn.Linear(2, 4),
nn.ReLU(),
- nn.Sequential(
- nn.Linear(4, 4),
- nn.ReLU(), nn.Linear(4, 2)
- )
+ nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2)),
)
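+
+# Sanity check: the flat buffer must hold every weight and bias,
+# (2*4 + 4) + (4*4 + 4) + (4*2 + 2) = 42 values for this model
+assert sum(p.numel() for p in model.parameters()) == 42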
######################################################################
-print('Before:')
+print("Before:")
for p in model.parameters():
print(p.size(), p.storage().size())
-whole = flatparam(model)
+flatparam(model)
-print('After:')
+print("After:")
for p in model.parameters():
print(p.size(), p.storage().size())
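+# For this model the "Before" loop prints six lines, one per parameter with
+# its own storage (e.g. "torch.Size([4, 2]) 8"); "After", parameters()
+# yields only the flat tensor, so the loop prints "torch.Size([42]) 42".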
######################################################################
-print('Check:')
+print("Check:")
input = torch.rand(100, 2)
targets = torch.rand(100, 2)
-optimizer = torch.optim.SGD([ whole ], lr = 1e-2)
+optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
mse = nn.MSELoss()
for e in range(10):