3 # @XREMOTE_HOST: elk.fleuret.org
4 # @XREMOTE_EXEC: python
5 # @XREMOTE_PRE: source ${HOME}/misc/venv/pytorch/bin/activate
6 # @XREMOTE_PRE: ln -sf ${HOME}/data/pytorch ./data
9 # Any copyright is dedicated to the Public Domain.
10 # https://creativecommons.org/publicdomain/zero/1.0/
12 # Written by Francois Fleuret <francois@fleuret.org>
14 import sys, os, argparse, time, math, itertools
16 import torch, torchvision
18 from torch import optim, nn
19 from torch.nn import functional as F
21 ######################################################################
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

######################################################################

# Command-line configuration; every option has a default so the script can
# be run without arguments.
parser = argparse.ArgumentParser(
    description="Very simple implementation of a VAE for teaching."
)

parser.add_argument("--nb_epochs", type=int, default=100)

# Step size of the Adam optimizer
parser.add_argument("--learning_rate", type=float, default=2e-4)

parser.add_argument("--batch_size", type=int, default=100)

# MNIST is stored under <data_dir>/mnist
parser.add_argument("--data_dir", type=str, default="./data/")

parser.add_argument("--log_filename", type=str, default="train.log")

# Dimension of the latent code z
parser.add_argument("--latent_dim", type=int, default=32)

# Number of feature maps in the hidden layers of both encoder and decoder
parser.add_argument("--nb_channels", type=int, default=128)

# Boolean switch: True when the flag is present on the command line
parser.add_argument("--no_dkl", action="store_true")

args = parser.parse_args()

# Truncates any previous log; kept open for the whole run and left for the
# interpreter to close at exit.
log_file = open(args.log_filename, "w", encoding="utf-8")
51 ######################################################################
def log_string(s):
    """Log the message ``s`` prefixed with a local-time timestamp.

    The line is appended to the module-level ``log_file`` (skipped when it is
    None) and echoed to stdout. Both streams are flushed immediately so the
    progress of a long run can be followed while it is still running.
    """
    t = time.strftime("%Y-%m-%d_%H:%M:%S - ", time.localtime())

    if log_file is not None:
        log_file.write(t + s + "\n")
        # Otherwise lines sit in the buffer until it fills or the script ends
        log_file.flush()

    print(t + s)
    sys.stdout.flush()
65 ######################################################################
def sample_gaussian(param):
    """Draw one sample from a diagonal Gaussian (reparameterization trick).

    ``param`` is a ``(mu, log_var)`` pair of tensors. The result is
    ``mu + exp(log_var / 2) * eps`` with ``eps`` standard normal noise of
    ``mu``'s shape, so gradients flow back into both ``mu`` and ``log_var``.
    """
    mu, log_var = param
    std = log_var.mul(0.5).exp()  # standard deviation = exp(log_var / 2)
    # Explicit size (rather than randn_like) gives fresh independent noise even
    # when mu is an expanded view; dtype/device follow mu.
    eps = torch.randn(mu.size(), dtype=mu.dtype, device=mu.device)
    return eps * std + mu


def log_p_gaussian(x, param):
    """Log-density of ``x`` under a diagonal Gaussian, one value per sample.

    ``param`` is a ``(mu, log_var)`` pair. ``x``, ``mu`` and ``log_var`` are
    flattened from dim 1 on and the per-coordinate log-densities are summed
    over that flattened dimension. Rows broadcast, so a single-row parameter
    pair (batch size 1) can be evaluated against a whole batch of ``x``.
    """
    mu, log_var = param[0].flatten(1), param[1].flatten(1)
    x = x.flatten(1)
    var = log_var.exp()
    return (
        -0.5 * ((x - mu).pow(2) / var) - 0.5 * log_var - 0.5 * math.log(2 * math.pi)
    ).sum(1)


def dkl_gaussians(param_a, param_b):
    """Closed-form KL(N_a || N_b) between diagonal Gaussians, per sample.

    Each argument is a ``(mu, log_var)`` pair. Tensors are flattened from dim 1
    on, and the per-coordinate divergence

        0.5 * (log var_b - log var_a - 1 + (mu_a - mu_b)^2 / var_b + var_a / var_b)

    is summed over the flattened dimension, giving one value per row (rows
    broadcast, e.g. a single-row prior against a batch).
    """
    mean_a, log_var_a = param_a[0].flatten(1), param_a[1].flatten(1)
    mean_b, log_var_b = param_b[0].flatten(1), param_b[1].flatten(1)
    var_a = log_var_a.exp()
    var_b = log_var_b.exp()
    return 0.5 * (
        log_var_b - log_var_a - 1 + (mean_a - mean_b).pow(2) / var_b + var_a / var_b
    ).sum(1)
94 ######################################################################
class LatentGivenImageNet(nn.Module):
    """Encoder network: maps a 1x28x28 image to the parameters of q(Z | x).

    ``forward(x)`` returns a ``(mu, log_var)`` pair, each of shape
    ``(batch, latent_dim)``, parameterizing a diagonal Gaussian over Z.
    """

    def __init__(self, nb_channels, latent_dim):
        # Must run before any submodule is assigned on self
        super().__init__()

        self.model = nn.Sequential(
            nn.Conv2d(1, nb_channels, kernel_size=1),  # to 28x28
            nn.ReLU(inplace=True),
            nn.Conv2d(nb_channels, nb_channels, kernel_size=5),  # to 24x24
            nn.ReLU(inplace=True),
            nn.Conv2d(nb_channels, nb_channels, kernel_size=5),  # to 20x20
            nn.ReLU(inplace=True),
            nn.Conv2d(nb_channels, nb_channels, kernel_size=4, stride=2),  # to 9x9
            nn.ReLU(inplace=True),
            nn.Conv2d(nb_channels, nb_channels, kernel_size=3, stride=2),  # to 4x4
            nn.ReLU(inplace=True),
            # to 1x1; mu and log_var are stacked along the channel dimension
            nn.Conv2d(nb_channels, 2 * latent_dim, kernel_size=4),
        )

    def forward(self, x):
        # (batch, 2 * latent_dim, 1, 1) -> (batch, 2, latent_dim)
        output = self.model(x).view(x.size(0), 2, -1)
        mu, log_var = output[:, 0], output[:, 1]
        return mu, log_var


class ImageGivenLatentNet(nn.Module):
    """Decoder network: maps a latent code z to the parameters of p(X | z).

    ``forward(z)`` expects ``z`` reshapeable to ``(batch, latent_dim, 1, 1)``
    and returns a ``(mu, log_var)`` pair, each of shape ``(batch, 1, 28, 28)``,
    parameterizing a diagonal Gaussian over pixels.
    """

    def __init__(self, nb_channels, latent_dim):
        # Must run before any submodule is assigned on self
        super().__init__()

        self.model = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, nb_channels, kernel_size=4),  # from 1x1
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(
                nb_channels, nb_channels, kernel_size=3, stride=2
            ),  # from 4x4
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(
                nb_channels, nb_channels, kernel_size=4, stride=2
            ),  # from 9x9
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size=5),  # from 20x20
            nn.ReLU(inplace=True),
            # Two output channels: pixel mean and pixel log-variance
            nn.ConvTranspose2d(nb_channels, 2, kernel_size=5),  # from 24x24
        )

    def forward(self, z):
        # Treat the latent vector as a 1x1 map with latent_dim channels
        output = self.model(z.view(z.size(0), -1, 1, 1))
        # 0:1 / 1:2 slices keep the channel dim: (batch, 1, 28, 28) each
        mu, log_var = output[:, 0:1], output[:, 1:2]
        return mu, log_var
148 ######################################################################
# MNIST gets its own sub-directory of --data_dir; download=True lets
# torchvision fetch the files when needed.
data_dir = os.path.join(args.data_dir, "mnist")

train_set, test_set = [
    torchvision.datasets.MNIST(data_dir, train=is_train, download=True)
    for is_train in (True, False)
]

# Pixel tensors reshaped to (N, 1, 28, 28) and cast to float, not yet
# normalized.
train_input, test_input = [
    split.data.view(-1, 1, 28, 28).float() for split in (train_set, test_set)
]
158 ######################################################################
# Encoder q(Z | x) and decoder p(X | z)
model_q_Z_given_x = LatentGivenImageNet(
    nb_channels=args.nb_channels, latent_dim=args.latent_dim
)

model_p_X_given_z = ImageGivenLatentNet(
    nb_channels=args.nb_channels, latent_dim=args.latent_dim
)

# Place the parameters on their final device *before* building the optimizer
# over them, as recommended by the torch.optim documentation.
model_p_X_given_z.to(device)
model_q_Z_given_x.to(device)

# A single optimizer updates encoder and decoder jointly on the same loss
optimizer = optim.Adam(
    itertools.chain(model_p_X_given_z.parameters(), model_q_Z_given_x.parameters()),
    lr=args.learning_rate,
)
176 ######################################################################
# Move both splits to the compute device
train_input = train_input.to(device)
test_input = test_input.to(device)

# Standardize with scalar statistics of the *training* pixels only; the test
# split reuses them, and save_image() uses them to undo the normalization.
train_mu = train_input.mean()
train_std = train_input.std()

for images in (train_input, test_input):
    images -= train_mu  # in place on the tensor
    images /= train_std
del images
184 ######################################################################
# Prior p(Z) = N(0, I): a single row of zeros for both the mean and the
# log-variance, broadcast against every batch.
zeros = train_input.new_zeros(1, args.latent_dim)

param_p_Z = zeros, zeros

for epoch in range(args.nb_epochs):
    acc_loss = 0.0

    for x in train_input.split(args.batch_size):
        # Encode, sample z with the reparameterization trick, decode
        param_q_Z_given_x = model_q_Z_given_x(x)
        z = sample_gaussian(param_q_Z_given_x)
        param_p_X_given_z = model_p_X_given_z(z)
        log_p_x_given_z = log_p_gaussian(x, param_p_X_given_z)

        if args.no_dkl:
            # Single-sample estimate of -E_q[log p(x, z) - log q(z | x)]
            log_q_z_given_x = log_p_gaussian(z, param_q_Z_given_x)
            log_p_z = log_p_gaussian(z, param_p_Z)
            log_p_x_z = log_p_x_given_z + log_p_z
            loss = -(log_p_x_z - log_q_z_given_x).mean()
        else:
            # Same objective, with KL(q(Z | x) || p(Z)) in closed form
            dkl_q_Z_given_x_from_p_Z = dkl_gaussians(param_q_Z_given_x, param_p_Z)
            loss = (-log_p_x_given_z + dkl_q_Z_given_x_from_p_Z).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch average stays exact when the
        # last split is smaller
        acc_loss += loss.item() * x.size(0)

    log_string(f"acc_loss {epoch} {acc_loss/train_input.size(0)}")
216 ######################################################################
def save_image(x, filename):
    """Write a batch of standardized images to ``filename`` as a single grid.

    The standardization is undone with the module-level ``train_mu`` /
    ``train_std``, values are clipped to [0, 255] and rescaled to [0, 1], then
    inverted (``1 - x``) so high-intensity pixels render dark. The grid holds
    16 images per row with light-grey (0.8) padding.
    """
    pixels = (x * train_std + train_mu).clamp(min=0, max=255)
    torchvision.utils.save_image(1 - pixels / 255, filename, nrow=16, pad_value=0.8)


# Inference only from here on: no_grad avoids recording autograd graphs for
# these forward passes.
with torch.no_grad():
    # Save a bunch of test images
    # NOTE(review): the listing omits the line defining this batch; 256 images
    # (16 rows of 16 in save_image's grid) is a reconstruction — confirm.
    x = test_input[:256]
    save_image(x, "input.png")

    # Save the same images after encoding / decoding: one sampled
    # reconstruction, and the decoder's mean
    param_q_Z_given_x = model_q_Z_given_x(x)
    z = sample_gaussian(param_q_Z_given_x)
    param_p_X_given_z = model_p_X_given_z(z)
    x = sample_gaussian(param_p_X_given_z)
    save_image(x, "output.png")
    save_image(param_p_X_given_z[0], "output_mean.png")

    # Generate a bunch of images: sample z from the prior p(Z), expanded to
    # one row per image, and decode it
    z = sample_gaussian(
        (param_p_Z[0].expand(x.size(0), -1), param_p_Z[1].expand(x.size(0), -1))
    )
    param_p_X_given_z = model_p_X_given_z(z)
    x = sample_gaussian(param_p_X_given_z)
    save_image(x, "synth.png")
    save_image(param_p_X_given_z[0], "synth_mean.png")
249 ######################################################################