Update.
[pytorch.git] / tiny_vae.py
1 #!/usr/bin/env python
2
3 # @XREMOTE_HOST: elk.fleuret.org
4 # @XREMOTE_EXEC: python
5 # @XREMOTE_PRE: source ${HOME}/misc/venv/pytorch/bin/activate
6 # @XREMOTE_PRE: ln -sf ${HOME}/data/pytorch ./data
7 # @XREMOTE_GET: *.png
8
9 # Any copyright is dedicated to the Public Domain.
10 # https://creativecommons.org/publicdomain/zero/1.0/
11
12 # Written by Francois Fleuret <francois@fleuret.org>
13
14 import sys, os, argparse, time, math, itertools
15
16 import torch, torchvision
17
18 from torch import optim, nn
19 from torch.nn import functional as F
20
21 ######################################################################
22
23 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
25 ######################################################################
26
27 parser = argparse.ArgumentParser(
28     description="Very simple implementation of a VAE for teaching."
29 )
30
31 parser.add_argument("--nb_epochs", type=int, default=25)
32
33 parser.add_argument("--learning_rate", type=float, default=1e-3)
34
35 parser.add_argument("--batch_size", type=int, default=100)
36
37 parser.add_argument("--data_dir", type=str, default="./data/")
38
39 parser.add_argument("--log_filename", type=str, default="train.log")
40
41 parser.add_argument("--latent_dim", type=int, default=32)
42
43 parser.add_argument("--nb_channels", type=int, default=128)
44
45 parser.add_argument("--no_dkl", action="store_true")
46
47 # With that option, do not follow the setup of the original VAE paper
48 # of forcing the variance of X|Z to 1 during training and to 0 for
49 # sampling, but optimize and use the variance.
50 parser.add_argument("--no_hacks", action="store_true")
51
52 args = parser.parse_args()
53
54 log_file = open(args.log_filename, "w")
55
56 ######################################################################
57
58
59 def log_string(s):
60     t = time.strftime("%Y-%m-%d_%H:%M:%S - ", time.localtime())
61
62     if log_file is not None:
63         log_file.write(t + s + "\n")
64         log_file.flush()
65
66     print(t + s)
67     sys.stdout.flush()
68
69
70 ######################################################################
71
72
73 def sample_gaussian(mu, log_var):
74     std = log_var.mul(0.5).exp()
75     return torch.randn(mu.size(), device=mu.device) * std + mu
76
77
78 def log_p_gaussian(x, mu, log_var):
79     var = log_var.exp()
80     return (
81         (-0.5 * ((x - mu).pow(2) / var) - 0.5 * log_var - 0.5 * math.log(2 * math.pi))
82         .flatten(1)
83         .sum(1)
84     )
85
86
87 def dkl_gaussians(mean_a, log_var_a, mean_b, log_var_b):
88     mean_a, log_var_a = mean_a.flatten(1), log_var_a.flatten(1)
89     mean_b, log_var_b = mean_b.flatten(1), log_var_b.flatten(1)
90     var_a = log_var_a.exp()
91     var_b = log_var_b.exp()
92     return 0.5 * (
93         log_var_b - log_var_a - 1 + (mean_a - mean_b).pow(2) / var_b + var_a / var_b
94     ).sum(1)
95
96
97 ######################################################################
98
99
100 class LatentGivenImageNet(nn.Module):
101     def __init__(self, nb_channels, latent_dim):
102         super().__init__()
103
104         self.model = nn.Sequential(
105             nn.Conv2d(1, nb_channels, kernel_size=1),  # to 28x28
106             nn.ReLU(inplace=True),
107             nn.Conv2d(nb_channels, nb_channels, kernel_size=5),  # to 24x24
108             nn.ReLU(inplace=True),
109             nn.Conv2d(nb_channels, nb_channels, kernel_size=5),  # to 20x20
110             nn.ReLU(inplace=True),
111             nn.Conv2d(nb_channels, nb_channels, kernel_size=4, stride=2),  # to 9x9
112             nn.ReLU(inplace=True),
113             nn.Conv2d(nb_channels, nb_channels, kernel_size=3, stride=2),  # to 4x4
114             nn.ReLU(inplace=True),
115             nn.Conv2d(nb_channels, 2 * latent_dim, kernel_size=4),
116         )
117
118     def forward(self, x):
119         output = self.model(x).view(x.size(0), 2, -1)
120         mu, log_var = output[:, 0], output[:, 1]
121         return mu, log_var
122
123
124 class ImageGivenLatentNet(nn.Module):
125     def __init__(self, nb_channels, latent_dim):
126         super().__init__()
127
128         self.model = nn.Sequential(
129             nn.ConvTranspose2d(latent_dim, nb_channels, kernel_size=4),
130             nn.ReLU(inplace=True),
131             nn.ConvTranspose2d(
132                 nb_channels, nb_channels, kernel_size=3, stride=2
133             ),  # from 4x4
134             nn.ReLU(inplace=True),
135             nn.ConvTranspose2d(
136                 nb_channels, nb_channels, kernel_size=4, stride=2
137             ),  # from 9x9
138             nn.ReLU(inplace=True),
139             nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size=5),  # from 20x20
140             nn.ReLU(inplace=True),
141             nn.ConvTranspose2d(nb_channels, 2, kernel_size=5),  # from 24x24
142         )
143
144     def forward(self, z):
145         output = self.model(z.view(z.size(0), -1, 1, 1))
146         mu, log_var = output[:, 0:1], output[:, 1:2]
147         if not args.no_hacks:
148             log_var[...] = 0
149         return mu, log_var
150
151
152 ######################################################################
153
154 data_dir = os.path.join(args.data_dir, "mnist")
155
156 train_set = torchvision.datasets.MNIST(data_dir, train=True, download=True)
157 train_input = train_set.data.view(-1, 1, 28, 28).float()
158
159 test_set = torchvision.datasets.MNIST(data_dir, train=False, download=True)
160 test_input = test_set.data.view(-1, 1, 28, 28).float()
161
162 ######################################################################
163
164 model_q_Z_given_x = LatentGivenImageNet(
165     nb_channels=args.nb_channels, latent_dim=args.latent_dim
166 )
167
168 model_p_X_given_z = ImageGivenLatentNet(
169     nb_channels=args.nb_channels, latent_dim=args.latent_dim
170 )
171
172 optimizer = optim.Adam(
173     itertools.chain(model_p_X_given_z.parameters(), model_q_Z_given_x.parameters()),
174     lr=args.learning_rate,
175 )
176
177 model_p_X_given_z.to(device)
178 model_q_Z_given_x.to(device)
179
180 ######################################################################
181
182 train_input, test_input = train_input.to(device), test_input.to(device)
183
184 train_mu, train_std = train_input.mean(), train_input.std()
185 train_input.sub_(train_mu).div_(train_std)
186 test_input.sub_(train_mu).div_(train_std)
187
188 ######################################################################
189
190 mean_p_Z = train_input.new_zeros(1, args.latent_dim)
191 log_var_p_Z = mean_p_Z
192
193 for epoch in range(args.nb_epochs):
194     acc_loss = 0
195
196     for x in train_input.split(args.batch_size):
197         mean_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x)
198         z = sample_gaussian(mean_q_Z_given_x, log_var_q_Z_given_x)
199         mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
200
201         if args.no_dkl:
202             log_q_z_given_x = log_p_gaussian(z, mean_q_Z_given_x, log_var_q_Z_given_x)
203             log_p_x_z = log_p_gaussian(
204                 x, mean_p_X_given_z, log_var_p_X_given_z
205             ) + log_p_gaussian(z, mean_p_Z, log_var_p_Z)
206             loss = -(log_p_x_z - log_q_z_given_x).mean()
207         else:
208             log_p_x_given_z = log_p_gaussian(x, mean_p_X_given_z, log_var_p_X_given_z)
209             dkl_q_Z_given_x_from_p_Z = dkl_gaussians(
210                 mean_q_Z_given_x, log_var_q_Z_given_x, mean_p_Z, log_var_p_Z
211             )
212             loss = (-log_p_x_given_z + dkl_q_Z_given_x_from_p_Z).mean()
213
214         optimizer.zero_grad()
215         loss.backward()
216         optimizer.step()
217
218         acc_loss += loss.item() * x.size(0)
219
220     log_string(f"acc_loss {epoch} {acc_loss/train_input.size(0)}")
221
222 ######################################################################
223
224
225 def save_image(x, filename):
226     x = x * train_std + train_mu
227     x = x.clamp(min=0, max=255) / 255
228     torchvision.utils.save_image(1 - x, filename, nrow=16, pad_value=0.8)
229
230
231 # Save a bunch of test images
232
233 x = test_input[:256]
234 save_image(x, "input.png")
235
236 # Save the same images after encoding / decoding
237
238 mean_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x)
239 z = sample_gaussian(mean_q_Z_given_x, log_var_q_Z_given_x)
240 mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
241 if args.no_hacks:
242     x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z)
243 else:
244     x = mean_p_X_given_z
245 save_image(x, "output.png")
246
247 # Generate a bunch of images
248
249 z = sample_gaussian(mean_p_Z.expand(x.size(0), -1), log_var_p_Z.expand(x.size(0), -1))
250 mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
251 if args.no_hacks:
252     x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z)
253 else:
254     x = mean_p_X_given_z
255 save_image(x, "synth.png")
256
257 ######################################################################