4599f00e1fc4f36844e69a918d6b4ebeb6e883a3
[pytorch.git] / lazy_linear.py
1 #!/usr/bin/env python
2
3 from torch import nn, Tensor
4
5 ######################################################################
6
7 class LazyLinear(nn.Module):
8
9     def __init__(self, out_dim, bias = True):
10         super(LazyLinear, self).__init__()
11         self.out_dim = out_dim
12         self.bias = bias
13         self.core = None
14
15     def forward(self, x):
16         x = x.view(x.size(0), -1)
17
18         if self.core is None:
19             if self.training:
20                 self.core = nn.Linear(x.size(1), self.out_dim, self.bias)
21             else:
22                 raise RuntimeError('Undefined LazyLinear core in inference mode.')
23
24         return self.core(x)
25
26 ######################################################################
27
28 model = nn.Sequential(nn.Conv2d(1, 8, kernel_size = 5),
29                       nn.ReLU(inplace = True),
30                       LazyLinear(128),
31                       nn.ReLU(inplace = True),
32                       nn.Linear(128, 10))
33
34 # model.eval()
35
36 input = Tensor(100, 1, 32, 32).normal_()
37
38 output = model(input)