Update.
[pytorch.git] / flatparam.py
1 #!/usr/bin/env python
2
3 import torch, torchvision
4 from torch import nn
5
6 ######################################################################
7
8
9 def _flatparam(model, whole, already=[], offset=0):
10     for v in model._parameters:
11         p = model._parameters[v]
12         e = p.numel()
13         s = p.size()
14         model._parameters[v] = whole[offset : offset + e].view(s)
15         with torch.no_grad():
16             model._parameters[v].copy_(p)
17         offset += e
18     already.append(model)
19     for m in model.modules():
20         if m not in already:
21             offset = _flatparam(m, whole, already, offset)
22     return offset
23
24
25 def flatparam(model):
26     n = sum(p.numel() for p in model.parameters())
27     whole = next(model.parameters()).new(n)  # Get same device and dtype
28     whole.requires_grad_()
29     _flatparam(model, whole)
30     model.parameters = lambda: iter([whole])
31
32
33 ######################################################################
34
35 model = nn.Sequential(
36     nn.Linear(2, 4),
37     nn.ReLU(),
38     nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2)),
39 )
40
41 ######################################################################
42
43 print("Before:")
44 for p in model.parameters():
45     print(p.size(), p.storage().size())
46
47 flatparam(model)
48
49 print("After:")
50 for p in model.parameters():
51     print(p.size(), p.storage().size())
52
53 ######################################################################
54
55 print("Check:")
56
57 input = torch.rand(100, 2)
58 targets = torch.rand(100, 2)
59 optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
60 mse = nn.MSELoss()
61
62 for e in range(10):
63     output = model(input)
64     loss = mse(output, targets)
65     print(e, loss.item())
66     optimizer.zero_grad()
67     loss.backward()
68     optimizer.step()