Update.
[pytorch.git] / tiny_vae.py
1 #!/usr/bin/env python
2
3 # @XREMOTE_HOST: elk.fleuret.org
4 # @XREMOTE_EXEC: python
5 # @XREMOTE_PRE: source ${HOME}/misc/venv/pytorch/bin/activate
6 # @XREMOTE_PRE: ln -sf ${HOME}/data/pytorch ./data
7 # @XREMOTE_GET: *.png
8
9 # Any copyright is dedicated to the Public Domain.
10 # https://creativecommons.org/publicdomain/zero/1.0/
11
12 # Written by Francois Fleuret <francois@fleuret.org>
13
14 import sys, os, argparse, time, math, itertools
15
16 import torch, torchvision
17
18 from torch import optim, nn
19 from torch.nn import functional as F
20
21 ######################################################################
22
23 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
25 ######################################################################
26
27 parser = argparse.ArgumentParser(
28     description="Very simple implementation of a VAE for teaching."
29 )
30
31 parser.add_argument("--nb_epochs", type=int, default=100)
32
33 parser.add_argument("--learning_rate", type=float, default=2e-4)
34
35 parser.add_argument("--batch_size", type=int, default=100)
36
37 parser.add_argument("--data_dir", type=str, default="./data/")
38
39 parser.add_argument("--log_filename", type=str, default="train.log")
40
41 parser.add_argument("--latent_dim", type=int, default=32)
42
43 parser.add_argument("--nb_channels", type=int, default=128)
44
45 parser.add_argument("--no_dkl", action="store_true")
46
47 args = parser.parse_args()
48
49 log_file = open(args.log_filename, "w")
50
51 ######################################################################
52
53
54 def log_string(s):
55     t = time.strftime("%Y-%m-%d_%H:%M:%S - ", time.localtime())
56
57     if log_file is not None:
58         log_file.write(t + s + "\n")
59         log_file.flush()
60
61     print(t + s)
62     sys.stdout.flush()
63
64
65 ######################################################################
66
67
68 def sample_gaussian(mu, log_var):
69     std = log_var.mul(0.5).exp()
70     return torch.randn(mu.size(), device=mu.device) * std + mu
71
72
73 def log_p_gaussian(x, mu, log_var):
74     var = log_var.exp()
75     return (
76         (-0.5 * ((x - mu).pow(2) / var) - 0.5 * log_var - 0.5 * math.log(2 * math.pi))
77         .flatten(1)
78         .sum(1)
79     )
80
81
82 def dkl_gaussians(mean_a, log_var_a, mean_b, log_var_b):
83     mean_a, log_var_a = mean_a.flatten(1), log_var_a.flatten(1)
84     mean_b, log_var_b = mean_b.flatten(1), log_var_b.flatten(1)
85     var_a = log_var_a.exp()
86     var_b = log_var_b.exp()
87     return 0.5 * (
88         log_var_b - log_var_a - 1 + (mean_a - mean_b).pow(2) / var_b + var_a / var_b
89     ).sum(1)
90
91
92 ######################################################################
93
94
95 class LatentGivenImageNet(nn.Module):
96     def __init__(self, nb_channels, latent_dim):
97         super().__init__()
98
99         self.model = nn.Sequential(
100             nn.Conv2d(1, nb_channels, kernel_size=1),  # to 28x28
101             nn.ReLU(inplace=True),
102             nn.Conv2d(nb_channels, nb_channels, kernel_size=5),  # to 24x24
103             nn.ReLU(inplace=True),
104             nn.Conv2d(nb_channels, nb_channels, kernel_size=5),  # to 20x20
105             nn.ReLU(inplace=True),
106             nn.Conv2d(nb_channels, nb_channels, kernel_size=4, stride=2),  # to 9x9
107             nn.ReLU(inplace=True),
108             nn.Conv2d(nb_channels, nb_channels, kernel_size=3, stride=2),  # to 4x4
109             nn.ReLU(inplace=True),
110             nn.Conv2d(nb_channels, 2 * latent_dim, kernel_size=4),
111         )
112
113     def forward(self, x):
114         output = self.model(x).view(x.size(0), 2, -1)
115         mu, log_var = output[:, 0], output[:, 1]
116         return mu, log_var
117
118
119 class ImageGivenLatentNet(nn.Module):
120     def __init__(self, nb_channels, latent_dim):
121         super().__init__()
122
123         self.model = nn.Sequential(
124             nn.ConvTranspose2d(latent_dim, nb_channels, kernel_size=4),
125             nn.ReLU(inplace=True),
126             nn.ConvTranspose2d(
127                 nb_channels, nb_channels, kernel_size=3, stride=2
128             ),  # from 4x4
129             nn.ReLU(inplace=True),
130             nn.ConvTranspose2d(
131                 nb_channels, nb_channels, kernel_size=4, stride=2
132             ),  # from 9x9
133             nn.ReLU(inplace=True),
134             nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size=5),  # from 20x20
135             nn.ReLU(inplace=True),
136             nn.ConvTranspose2d(nb_channels, 2, kernel_size=5),  # from 24x24
137         )
138
139     def forward(self, z):
140         output = self.model(z.view(z.size(0), -1, 1, 1))
141         mu, log_var = output[:, 0:1], output[:, 1:2]
142         # log_var.flatten(1)[...]=log_var.flatten(1)[:,:1]
143         return mu, log_var
144
145
146 ######################################################################
147
148 data_dir = os.path.join(args.data_dir, "mnist")
149
150 train_set = torchvision.datasets.MNIST(data_dir, train=True, download=True)
151 train_input = train_set.data.view(-1, 1, 28, 28).float()
152
153 test_set = torchvision.datasets.MNIST(data_dir, train=False, download=True)
154 test_input = test_set.data.view(-1, 1, 28, 28).float()
155
156 ######################################################################
157
158 model_q_Z_given_x = LatentGivenImageNet(
159     nb_channels=args.nb_channels, latent_dim=args.latent_dim
160 )
161
162 model_p_X_given_z = ImageGivenLatentNet(
163     nb_channels=args.nb_channels, latent_dim=args.latent_dim
164 )
165
166 optimizer = optim.Adam(
167     itertools.chain(model_p_X_given_z.parameters(), model_q_Z_given_x.parameters()),
168     lr=args.learning_rate,
169 )
170
171 model_p_X_given_z.to(device)
172 model_q_Z_given_x.to(device)
173
174 ######################################################################
175
176 train_input, test_input = train_input.to(device), test_input.to(device)
177
178 train_mu, train_std = train_input.mean(), train_input.std()
179 train_input.sub_(train_mu).div_(train_std)
180 test_input.sub_(train_mu).div_(train_std)
181
182 ######################################################################
183
184 mean_p_Z = train_input.new_zeros(1, args.latent_dim)
185 log_var_p_Z = mean_p_Z
186
187 for epoch in range(args.nb_epochs):
188     acc_loss = 0
189
190     for x in train_input.split(args.batch_size):
191         mean_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x)
192         z = sample_gaussian(mean_q_Z_given_x, log_var_q_Z_given_x)
193         mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
194
195         if args.no_dkl:
196             log_q_z_given_x = log_p_gaussian(z, mean_q_Z_given_x, log_var_q_Z_given_x)
197             log_p_x_z = log_p_gaussian(
198                 x, mean_p_X_given_z, log_var_p_X_given_z
199             ) + log_p_gaussian(z, mean_p_Z, log_var_p_Z)
200             loss = -(log_p_x_z - log_q_z_given_x).mean()
201         else:
202             log_p_x_given_z = log_p_gaussian(x, mean_p_X_given_z, log_var_p_X_given_z)
203             dkl_q_Z_given_x_from_p_Z = dkl_gaussians(
204                 mean_q_Z_given_x, log_var_q_Z_given_x, mean_p_Z, log_var_p_Z
205             )
206             loss = (-log_p_x_given_z + dkl_q_Z_given_x_from_p_Z).mean()
207
208         optimizer.zero_grad()
209         loss.backward()
210         optimizer.step()
211
212         acc_loss += loss.item() * x.size(0)
213
214     log_string(f"acc_loss {epoch} {acc_loss/train_input.size(0)}")
215
216 ######################################################################
217
218
219 def save_image(x, filename):
220     x = x * train_std + train_mu
221     x = x.clamp(min=0, max=255) / 255
222     torchvision.utils.save_image(1 - x, filename, nrow=16, pad_value=0.8)
223
224
225 # Save a bunch of test images
226
227 x = test_input[:256]
228 save_image(x, "input.png")
229
230 # Save the same images after encoding / decoding
231
232 mean_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x)
233 z = sample_gaussian(mean_q_Z_given_x, log_var_q_Z_given_x)
234 mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
235 x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z)
236 save_image(x, "output.png")
237
238 # Generate a bunch of images
239
240 z = sample_gaussian(mean_p_Z.expand(x.size(0), -1), log_var_p_Z.expand(x.size(0), -1))
241 mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
242 x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z)
243 save_image(x, "synth.png")
244
245 ######################################################################