Initial commit.
[pytorch.git] / flatparam.py
1 #!/usr/bin/env python
2
3 import torch, torchvision
4 from torch import nn
5
6 ######################################################################
7
8 def flatparam(model):
9     with torch.no_grad():
10         n = sum(p.numel() for p in model.parameters())
11         big = next(model.parameters()).new(n) # Get same device and dtype
12         k = 0
13         for p in model.parameters():
14             tmp = p.new(0).set_(p)
15             p.set_(big.storage(), k, p.size()).copy_(tmp)
16             k += p.numel()
17
18 ######################################################################
19
20 model = nn.Sequential(
21     nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 2)
22 )
23
24 print('Before:')
25 for p in model.parameters():
26     print(p.size(), p.storage().size())
27
28 flatparam(model)
29
30 print('After:')
31 for p in model.parameters():
32     print(p.size(), p.storage().size())
33
34 ######################################################################
35
36 print('Check:')
37
38 input = torch.rand(100, 2)
39 targets = torch.rand(100, 2)
40 optimizer = torch.optim.SGD(model.parameters(), lr = 1e-2)
41 mse = nn.MSELoss()
42
43 for e in range(10):
44     output = model(input)
45     loss = mse(output, targets)
46     print(e, loss.item())
47     optimizer.zero_grad()
48     loss.backward()
49     optimizer.step()