3 import torch, torchvision
6 ######################################################################
9 def _flatparam(model, whole, already=[], offset=0):
10 for v in model._parameters:
11 p = model._parameters[v]
14 model._parameters[v] = whole[offset : offset + e].view(s)
16 model._parameters[v].copy_(p)
19 for m in model.modules():
21 offset = _flatparam(m, whole, already, offset)
26 n = sum(p.numel() for p in model.parameters())
27 whole = next(model.parameters()).new(n) # Get same device and dtype
28 whole.requires_grad_()
29 _flatparam(model, whole)
30 model.parameters = lambda: iter([whole])
33 ######################################################################
35 model = nn.Sequential(
38 nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2)),
41 ######################################################################
44 for p in model.parameters():
45 print(p.size(), p.storage().size())
50 for p in model.parameters():
51 print(p.size(), p.storage().size())
53 ######################################################################
57 input = torch.rand(100, 2)
58 targets = torch.rand(100, 2)
59 optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
64 loss = mse(output, targets)