Update.
[pytorch.git] / lazy_linear.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 from torch import nn, Tensor
9
10 ######################################################################
11
12
13 class LazyLinear(nn.Module):
14     def __init__(self, out_dim, bias=True):
15         super().__init__()
16         self.out_dim = out_dim
17         self.bias = bias
18         self.core = None
19
20     def forward(self, x):
21         x = x.view(x.size(0), -1)
22
23         if self.core is None:
24             if self.training:
25                 self.core = nn.Linear(x.size(1), self.out_dim, self.bias)
26             else:
27                 raise RuntimeError("Undefined LazyLinear core in inference mode.")
28
29         return self.core(x)
30
31     def named_parameters(self, memo=None, prefix=""):
32         assert self.core is not None, "Parameters not yet defined"
33         return super().named_parameters(memo, prefix)
34
35
36 ######################################################################
37
38 if __name__ == "__main__":
39     model = nn.Sequential(
40         nn.Conv2d(3, 8, kernel_size=5),
41         nn.ReLU(inplace=True),
42         LazyLinear(128),
43         nn.ReLU(inplace=True),
44         nn.Linear(128, 10),
45     )
46
47     # model.eval()
48
49     input = Tensor(100, 3, 32, 32).normal_()
50
51     output = model(input)
52
53     for n, x in model.named_parameters():
54         print(n, x.size())