Update
[pytorch] / lazy_linear.py
1 #!/usr/bin/env python
2
3 from torch import nn, Tensor
4
5 ######################################################################
6
7 class LazyLinear(nn.Module):
8
9     def __init__(self, out_dim, bias = True):
10         super(LazyLinear, self).__init__()
11         self.out_dim = out_dim
12         self.bias = bias
13         self.core = None
14
15     def forward(self, x):
16         x = x.view(x.size(0), -1)
17
18         if self.core is None:
19             if self.training:
20                 self.core = nn.Linear(x.size(1), self.out_dim, self.bias)
21             else:
22                 raise RuntimeError('Undefined LazyLinear core in inference mode.')
23
24         return self.core(x)
25
26     def named_parameters(self, memo=None, prefix=''):
27         assert self.core is not None, 'Parameters not yet defined'
28         return super(LazyLinear, self).named_parameters(memo, prefix)
29
30 ######################################################################
31
32 if __name__ == "__main__":
33     model = nn.Sequential(nn.Conv2d(3, 8, kernel_size = 5),
34                           nn.ReLU(inplace = True),
35                           LazyLinear(128),
36                           nn.ReLU(inplace = True),
37                           nn.Linear(128, 10))
38
39     # model.eval()
40
41     input = Tensor(100, 3, 32, 32).normal_()
42
43     output = model(input)
44
45     for n, x in model.named_parameters():
46         print(n, x.size())
47