Niiice.
[pytorch.git] / flatparam.py
1 #!/usr/bin/env python
2
3 import torch, torchvision
4 from torch import nn
5
6 ######################################################################
7
8 def _flatparam(model, whole, already = [], offset = 0):
9     for v in model._parameters:
10         p = model._parameters[v]
11         e = p.numel()
12         s = p.size()
13         model._parameters[v] = whole[offset:offset+e].view(s)
14         with torch.no_grad():
15             model._parameters[v].copy_(p)
16         offset += e
17     already.append(model)
18     for m in model.modules():
19         if m not in already:
20             offset = _flatparam(m, whole, already, offset)
21     return offset
22
23 def flatparam(model):
24     n = sum(p.numel() for p in model.parameters())
25     whole = next(model.parameters()).new(n) # Get same device and dtype
26     whole.requires_grad_()
27     _flatparam(model, whole)
28     model.parameters = lambda: iter([ whole ])
29
30 ######################################################################
31
32 model = nn.Sequential(
33     nn.Linear(2, 4),
34     nn.ReLU(),
35     nn.Sequential(
36         nn.Linear(4, 4),
37         nn.ReLU(), nn.Linear(4, 2)
38     )
39 )
40
41 ######################################################################
42
43 print('Before:')
44 for p in model.parameters():
45     print(p.size(), p.storage().size())
46
47 flatparam(model)
48
49 print('After:')
50 for p in model.parameters():
51     print(p.size(), p.storage().size())
52
53 ######################################################################
54
55 print('Check:')
56
57 input = torch.rand(100, 2)
58 targets = torch.rand(100, 2)
59 optimizer = torch.optim.SGD(model.parameters(), lr = 1e-2)
60 mse = nn.MSELoss()
61
62 for e in range(10):
63     output = model(input)
64     loss = mse(output, targets)
65     print(e, loss.item())
66     optimizer.zero_grad()
67     loss.backward()
68     optimizer.step()