# Update.
# [pytorch.git] / tiny_vae.py
1 #!/usr/bin/env python
2
3 # @XREMOTE_HOST: elk.fleuret.org
4 # @XREMOTE_EXEC: python
5 # @XREMOTE_PRE: source ${HOME}/misc/venv/pytorch/bin/activate
6 # @XREMOTE_PRE: ln -sf ${HOME}/data/pytorch ./data
7 # @XREMOTE_GET: *.png
8
9 # Any copyright is dedicated to the Public Domain.
10 # https://creativecommons.org/publicdomain/zero/1.0/
11
12 # Written by Francois Fleuret <francois@fleuret.org>
13
14 import sys, os, argparse, time, math, itertools
15
16 import torch, torchvision
17
18 from torch import optim, nn
19 from torch.nn import functional as F
20
21 ######################################################################
22
# Prefer the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
24
25 ######################################################################
26
parser = argparse.ArgumentParser(
    description="Very simple implementation of a VAE for teaching."
)

# (flag, keyword-arguments) pairs describing the command-line interface.
_cli_options = [
    ("--nb_epochs", dict(type=int, default=100)),
    ("--learning_rate", dict(type=float, default=1e-4)),
    ("--batch_size", dict(type=int, default=100)),
    ("--data_dir", dict(type=str, default="./data/")),
    ("--log_filename", dict(type=str, default="train.log")),
    ("--latent_dim", dict(type=int, default=32)),
    ("--nb_channels", dict(type=int, default=64)),
    ("--no_dkl", dict(action="store_true")),
    ("--beta", dict(type=float, default=1.0)),
]

for _flag, _kwargs in _cli_options:
    parser.add_argument(_flag, **_kwargs)

args = parser.parse_args()

# Every log message is mirrored to this file for the whole run.
log_file = open(args.log_filename, "w")
52
53 ######################################################################
54
55
def log_string(s):
    """Print *s* with a timestamp prefix and mirror the line to log_file."""
    stamp = time.strftime("%Y-%m-%d_%H:%M:%S - ", time.localtime())
    line = stamp + s

    if log_file is not None:
        log_file.write(line + "\n")
        log_file.flush()

    print(line)
    sys.stdout.flush()
65
66
67 ######################################################################
68
69
def sample_gaussian(param):
    """Reparameterized sample from N(mu, exp(log_var)), param = (mu, log_var)."""
    mean, log_variance = param
    noise = torch.randn(mean.size(), device=mean.device)
    return mean + noise * log_variance.mul(0.5).exp()
74
75
def log_p_gaussian(x, param):
    """Log-density of x under the diagonal Gaussian N(mu, exp(log_var)),
    summed over all non-batch dimensions (one scalar per sample)."""
    mean, log_variance = param
    elementwise = -0.5 * (
        (x - mean).pow(2) / log_variance.exp()
        + log_variance
        + math.log(2 * math.pi)
    )
    return elementwise.flatten(1).sum(1)
84
85
def dkl_gaussians(param_a, param_b):
    """KL(N_a || N_b) between diagonal Gaussians, one value per batch row."""
    mu_a, lv_a = param_a[0].flatten(1), param_a[1].flatten(1)
    mu_b, lv_b = param_b[0].flatten(1), param_b[1].flatten(1)
    var_a, var_b = lv_a.exp(), lv_b.exp()
    per_dim = lv_b - lv_a - 1 + (mu_a - mu_b).pow(2) / var_b + var_a / var_b
    return 0.5 * per_dim.sum(1)
94
95
96 ######################################################################
97
98
class LatentGivenImageNet(nn.Module):
    """Encoder q(Z|X): maps a 1x28x28 image to the mean and log-variance
    of a diagonal Gaussian over the latent_dim-dimensional latent space."""

    def __init__(self, nb_channels, latent_dim):
        super().__init__()

        # Successive spatial sizes: 28 -> 28 -> 24 -> 20 -> 9 -> 4 -> 1.
        convolutions = [
            nn.Conv2d(1, nb_channels, kernel_size=1),
            nn.Conv2d(nb_channels, nb_channels, kernel_size=5),
            nn.Conv2d(nb_channels, nb_channels, kernel_size=5),
            nn.Conv2d(nb_channels, nb_channels, kernel_size=4, stride=2),
            nn.Conv2d(nb_channels, nb_channels, kernel_size=3, stride=2),
            nn.Conv2d(nb_channels, 2 * latent_dim, kernel_size=4),
        ]

        layers = []
        for conv in convolutions[:-1]:
            layers += [conv, nn.ReLU(inplace=True)]
        layers.append(convolutions[-1])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # The last convolution emits 2 * latent_dim values per image:
        # first half is the mean, second half the log-variance.
        params = self.model(x).view(x.size(0), 2, -1)
        return params[:, 0], params[:, 1]
121
122
class ImageGivenLatentNet(nn.Module):
    """Decoder p(X|Z): maps a latent vector to the per-pixel mean and
    log-variance of a Gaussian over a 1x28x28 image."""

    def __init__(self, nb_channels, latent_dim):
        super().__init__()

        # Successive spatial sizes: 1 -> 4 -> 9 -> 20 -> 24 -> 28.
        deconvolutions = [
            nn.ConvTranspose2d(latent_dim, nb_channels, kernel_size=4),
            nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size=3, stride=2),
            nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size=4, stride=2),
            nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size=5),
            nn.ConvTranspose2d(nb_channels, 2, kernel_size=5),
        ]

        layers = []
        for deconv in deconvolutions[:-1]:
            layers += [deconv, nn.ReLU(inplace=True)]
        layers.append(deconvolutions[-1])

        self.model = nn.Sequential(*layers)

    def forward(self, z):
        # Feed the latent vector in as a 1x1 "image" with latent_dim channels.
        output = self.model(z.view(z.size(0), -1, 1, 1))
        # Channel 0 is the pixel-wise mean, channel 1 the log-variance.
        return output[:, 0:1], output[:, 1:2]
148
149
150 ######################################################################
151
data_dir = os.path.join(args.data_dir, "mnist")

# MNIST digits as float tensors of shape (N, 1, 28, 28), raw [0, 255] range.
train_set = torchvision.datasets.MNIST(data_dir, train=True, download=True)
test_set = torchvision.datasets.MNIST(data_dir, train=False, download=True)

train_input = train_set.data.view(-1, 1, 28, 28).float()
test_input = test_set.data.view(-1, 1, 28, 28).float()
159
160 ######################################################################
161
162
def save_images(model_q_Z_given_x, model_p_X_given_z, prefix=""):
    """Write PNG grids of train/test inputs, their reconstructions, and
    fresh samples from the prior, using the given encoder / decoder.

    Filenames are prefixed with *prefix* (e.g. "epoch_0025_"). Relies on the
    module-level train_input, test_input, train_mu, train_std and param_p_Z.
    """

    def save_image(x, filename):
        # Undo the training normalization, map to [0, 1], and save inverted
        # (dark digits on a light background).
        x = x * train_std + train_mu
        x = x.clamp(min=0, max=255) / 255
        torchvision.utils.save_image(1 - x, filename, nrow=16, pad_value=0.8)
        # BUG FIX: the f-string had no placeholder ("wrote (unknown)"), so
        # the written filename was never logged.
        log_string(f"wrote {filename}")

    # Visualization only: no gradients are needed, so skip autograd tracking.
    with torch.no_grad():
        # Save a bunch of train images

        x = train_input[:256]
        save_image(x, f"{prefix}train_input.png")

        # Save the same images after encoding / decoding

        param_q_Z_given_x = model_q_Z_given_x(x)
        z = sample_gaussian(param_q_Z_given_x)
        param_p_X_given_z = model_p_X_given_z(z)
        x = sample_gaussian(param_p_X_given_z)
        save_image(x, f"{prefix}train_output.png")
        save_image(param_p_X_given_z[0], f"{prefix}train_output_mean.png")

        # Save a bunch of test images

        x = test_input[:256]
        save_image(x, f"{prefix}input.png")

        # Save the same images after encoding / decoding

        param_q_Z_given_x = model_q_Z_given_x(x)
        z = sample_gaussian(param_q_Z_given_x)
        param_p_X_given_z = model_p_X_given_z(z)
        x = sample_gaussian(param_p_X_given_z)
        save_image(x, f"{prefix}output.png")
        save_image(param_p_X_given_z[0], f"{prefix}output_mean.png")

        # Generate a bunch of images from the prior p(Z)

        z = sample_gaussian(
            (param_p_Z[0].expand(x.size(0), -1), param_p_Z[1].expand(x.size(0), -1))
        )
        param_p_X_given_z = model_p_X_given_z(z)
        x = sample_gaussian(param_p_X_given_z)
        save_image(x, f"{prefix}synth.png")
        save_image(param_p_X_given_z[0], f"{prefix}synth_mean.png")
207
208
209 ######################################################################
210
model_q_Z_given_x = LatentGivenImageNet(
    nb_channels=args.nb_channels, latent_dim=args.latent_dim
)
model_p_X_given_z = ImageGivenLatentNet(
    nb_channels=args.nb_channels, latent_dim=args.latent_dim
)

# Module.to moves parameter data in place, so the optimizer created below
# sees the same parameter tensors.
model_p_X_given_z.to(device)
model_q_Z_given_x.to(device)

# A single optimizer updates encoder and decoder jointly.
all_parameters = itertools.chain(
    model_p_X_given_z.parameters(), model_q_Z_given_x.parameters()
)
optimizer = optim.Adam(all_parameters, lr=args.learning_rate)

######################################################################

train_input, test_input = train_input.to(device), test_input.to(device)

# Standardize both splits with statistics computed on the train split only.
train_mu, train_std = train_input.mean(), train_input.std()
train_input.sub_(train_mu).div_(train_std)
test_input.sub_(train_mu).div_(train_std)

######################################################################

# The prior p(Z) is a standard Gaussian: mean 0, log-variance 0.
zeros = train_input.new_zeros(1, args.latent_dim)
param_p_Z = zeros, zeros
240
# Main training loop: maximize the ELBO (equivalently, minimize its negation).
for n_epoch in range(args.nb_epochs):
    acc_loss = 0

    for x in train_input.split(args.batch_size):
        # Encode, sample a latent with the reparameterization trick, decode.
        param_q_Z_given_x = model_q_Z_given_x(x)
        z = sample_gaussian(param_q_Z_given_x)
        param_p_X_given_z = model_p_X_given_z(z)
        log_p_x_given_z = log_p_gaussian(x, param_p_X_given_z)

        if args.no_dkl:
            # Single-sample Monte-Carlo estimate of the ELBO:
            # E_q[log p(x, z) - log q(z|x)].
            log_q_z_given_x = log_p_gaussian(z, param_q_Z_given_x)
            log_p_z = log_p_gaussian(z, param_p_Z)
            # BUG FIX: was "log_p_x_given_z + log_p_x_z", which read the
            # not-yet-defined log_p_x_z (NameError with --no_dkl); the joint
            # is log p(x|z) + log p(z).
            log_p_x_z = log_p_x_given_z + log_p_z
            loss = -(log_p_x_z - log_q_z_given_x).mean()
        else:
            # Closed-form KL(q(Z|x) || p(Z)), with beta-VAE weighting.
            dkl_q_Z_given_x_from_p_Z = dkl_gaussians(param_q_Z_given_x, param_p_Z)
            loss = -(log_p_x_given_z - args.beta * dkl_q_Z_given_x_from_p_Z).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc_loss += loss.item() * x.size(0)

    log_string(f"acc_loss {n_epoch} {acc_loss/train_input.size(0)}")

    # Dump image grids every 25 epochs.
    if (n_epoch + 1) % 25 == 0:
        save_images(model_q_Z_given_x, model_p_X_given_z, f"epoch_{n_epoch+1:04d}_")
269
270 ######################################################################