Update.
[pytorch.git] / tinyae.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import sys, argparse, time
9
10 import torch, torchvision
11
12 from torch import optim, nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
17 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
19 ######################################################################
20
21 parser = argparse.ArgumentParser(description="Tiny LeNet-like auto-encoder.")
22
23 parser.add_argument("--nb_epochs", type=int, default=25)
24
25 parser.add_argument("--batch_size", type=int, default=100)
26
27 parser.add_argument("--data_dir", type=str, default="./data/")
28
29 parser.add_argument("--log_filename", type=str, default="train.log")
30
31 parser.add_argument("--embedding_dim", type=int, default=8)
32
33 parser.add_argument("--nb_channels", type=int, default=32)
34
35 args = parser.parse_args()
36
37 log_file = open(args.log_filename, "w")
38
39 ######################################################################
40
41
42 def log_string(s):
43     t = time.strftime("%Y-%m-%d_%H:%M:%S - ", time.localtime())
44
45     if log_file is not None:
46         log_file.write(t + s + "\n")
47         log_file.flush()
48
49     print(t + s)
50     sys.stdout.flush()
51
52
53 ######################################################################
54
55
56 class AutoEncoder(nn.Module):
57     def __init__(self, nb_channels, embedding_dim):
58         super().__init__()
59
60         self.encoder = nn.Sequential(
61             nn.Conv2d(1, nb_channels, kernel_size=5),  # to 24x24
62             nn.ReLU(inplace=True),
63             nn.Conv2d(nb_channels, nb_channels, kernel_size=5),  # to 20x20
64             nn.ReLU(inplace=True),
65             nn.Conv2d(nb_channels, nb_channels, kernel_size=4, stride=2),  # to 9x9
66             nn.ReLU(inplace=True),
67             nn.Conv2d(nb_channels, nb_channels, kernel_size=3, stride=2),  # to 4x4
68             nn.ReLU(inplace=True),
69             nn.Conv2d(nb_channels, embedding_dim, kernel_size=4),
70         )
71
72         self.decoder = nn.Sequential(
73             nn.ConvTranspose2d(embedding_dim, nb_channels, kernel_size=4),
74             nn.ReLU(inplace=True),
75             nn.ConvTranspose2d(
76                 nb_channels, nb_channels, kernel_size=3, stride=2
77             ),  # from 4x4
78             nn.ReLU(inplace=True),
79             nn.ConvTranspose2d(
80                 nb_channels, nb_channels, kernel_size=4, stride=2
81             ),  # from 9x9
82             nn.ReLU(inplace=True),
83             nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size=5),  # from 20x20
84             nn.ReLU(inplace=True),
85             nn.ConvTranspose2d(nb_channels, 1, kernel_size=5),  # from 24x24
86         )
87
88     def encode(self, x):
89         return self.encoder(x).view(x.size(0), -1)
90
91     def decode(self, z):
92         return self.decoder(z.view(z.size(0), -1, 1, 1))
93
94     def forward(self, x):
95         x = self.encoder(x)
96         x = self.decoder(x)
97         return x
98
99
100 ######################################################################
101
102 train_set = torchvision.datasets.MNIST(
103     args.data_dir + "/mnist/", train=True, download=True
104 )
105 train_input = train_set.data.view(-1, 1, 28, 28).float()
106
107 test_set = torchvision.datasets.MNIST(
108     args.data_dir + "/mnist/", train=False, download=True
109 )
110 test_input = test_set.data.view(-1, 1, 28, 28).float()
111
112 ######################################################################
113
114 model = AutoEncoder(args.nb_channels, args.embedding_dim)
115 optimizer = optim.Adam(model.parameters(), lr=1e-3)
116
117 model.to(device)
118
119 train_input, test_input = train_input.to(device), test_input.to(device)
120
121 mu, std = train_input.mean(), train_input.std()
122 train_input.sub_(mu).div_(std)
123 test_input.sub_(mu).div_(std)
124
125 ######################################################################
126
127 for epoch in range(args.nb_epochs):
128     acc_loss = 0
129
130     for input in train_input.split(args.batch_size):
131         output = model(input)
132         loss = 0.5 * (output - input).pow(2).sum() / input.size(0)
133
134         optimizer.zero_grad()
135         loss.backward()
136         optimizer.step()
137
138         acc_loss += loss.item()
139
140     log_string("acc_loss {:d} {:f}.".format(epoch, acc_loss))
141
142 ######################################################################
143
144 input = test_input[:256]
145
146 # Encode / decode
147
148 z = model.encode(input)
149 output = model.decode(z)
150
151 torchvision.utils.save_image(1 - input, "ae-input.png", nrow=16, pad_value=0.8)
152 torchvision.utils.save_image(1 - output, "ae-output.png", nrow=16, pad_value=0.8)
153
154 # Dumb synthesis
155
156 z = model.encode(input)
157 mu, std = z.mean(0), z.std(0)
158 z = z.normal_() * std + mu
159 output = model.decode(z)
160
161 torchvision.utils.save_image(1 - output, "ae-synth.png", nrow=16, pad_value=0.8)
162
163 ######################################################################