Update.
[pytorch.git] / flatparam.py
1 #!/usr/bin/env python
2
3 import torch, torchvision
4 from torch import nn
5
6 ######################################################################
7
8 def _flatparam(model, whole, already = [], offset = 0):
9     for v in model._parameters:
10         p = model._parameters[v]
11         e = p.numel()
12         s = p.size()
13         model._parameters[v] = whole[offset:offset+e].view(s)
14         with torch.no_grad():
15             model._parameters[v].copy_(p)
16         offset += e
17     already.append(model)
18     for m in model.modules():
19         if m not in already:
20             offset = _flatparam(m, whole, already, offset)
21     return offset
22
23 def flatparam(model):
24     n = sum(p.numel() for p in model.parameters())
25     whole = next(model.parameters()).new(n) # Get same device and dtype
26     whole.requires_grad_()
27     _flatparam(model, whole)
28     model.parameters = lambda: iter([ whole ])
29
30 ######################################################################
31
32 model = nn.Sequential(
33     nn.Linear(2, 4),
34     nn.ReLU(),
35     nn.Sequential(
36         nn.Linear(4, 4),
37         nn.ReLU(),
38         nn.Linear(4, 2)
39     )
40 )
41
42 ######################################################################
43
44 print('Before:')
45 for p in model.parameters():
46     print(p.size(), p.storage().size())
47
48 flatparam(model)
49
50 print('After:')
51 for p in model.parameters():
52     print(p.size(), p.storage().size())
53
54 ######################################################################
55
56 print('Check:')
57
58 input = torch.rand(100, 2)
59 targets = torch.rand(100, 2)
60 optimizer = torch.optim.SGD(model.parameters(), lr = 1e-2)
61 mse = nn.MSELoss()
62
63 for e in range(10):
64     output = model(input)
65     loss = mse(output, targets)
66     print(e, loss.item())
67     optimizer.zero_grad()
68     loss.backward()
69     optimizer.step()