Cosmetics + added the figures.
[pytorch.git] / minidiffusion.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, argparse
9
10 import matplotlib.pyplot as plt
11
12 import torch, torchvision
13 from torch import nn
14 from torch.nn import functional as F
15
16 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
18 print(f'device {device}')
19
20 ######################################################################
21
22 def sample_gaussian_mixture(nb):
23     p, std = 0.3, 0.2
24     result = torch.randn(nb, 1) * std
25     result = result + torch.sign(torch.rand(result.size()) - p) / 2
26     return result
27
28 def sample_ramp(nb):
29     result = torch.min(torch.rand(nb, 1), torch.rand(nb, 1))
30     return result
31
32 def sample_two_discs(nb):
33     a = torch.rand(nb) * math.pi * 2
34     b = torch.rand(nb).sqrt()
35     q = (torch.rand(nb) <= 0.5).long()
36     b = b * (0.3 + 0.2 * q)
37     result = torch.empty(nb, 2)
38     result[:, 0] = a.cos() * b - 0.5 + q
39     result[:, 1] = a.sin() * b - 0.5 + q
40     return result
41
42 def sample_disc_grid(nb):
43     a = torch.rand(nb) * math.pi * 2
44     b = torch.rand(nb).sqrt()
45     N = 4
46     q = (torch.randint(N, (nb,)) - (N - 1) / 2) / ((N - 1) / 2)
47     r = (torch.randint(N, (nb,)) - (N - 1) / 2) / ((N - 1) / 2)
48     b = b * 0.1
49     result = torch.empty(nb, 2)
50     result[:, 0] = a.cos() * b + q
51     result[:, 1] = a.sin() * b + r
52     return result
53
54 def sample_spiral(nb):
55     u = torch.rand(nb)
56     rho = u * 0.65 + 0.25 + torch.rand(nb) * 0.15
57     theta = u * math.pi * 3
58     result = torch.empty(nb, 2)
59     result[:, 0] = theta.cos() * rho
60     result[:, 1] = theta.sin() * rho
61     return result
62
63 def sample_mnist(nb):
64     train_set = torchvision.datasets.MNIST(root = './data/', train = True, download = True)
65     result = train_set.data[:nb].to(device).view(-1, 1, 28, 28).float()
66     return result
67
68 samplers = {
69     'gaussian_mixture': sample_gaussian_mixture,
70     'ramp': sample_ramp,
71     'two_discs': sample_two_discs,
72     'disc_grid': sample_disc_grid,
73     'spiral': sample_spiral,
74     'mnist': sample_mnist,
75 }
76
77 ######################################################################
78
79 parser = argparse.ArgumentParser(
80     description = '''A minimal implementation of Jonathan Ho, Ajay Jain, Pieter Abbeel
81 "Denoising Diffusion Probabilistic Models" (2020)
82 https://arxiv.org/abs/2006.11239''',
83
84     formatter_class = argparse.ArgumentDefaultsHelpFormatter
85 )
86
87 parser.add_argument('--seed',
88                     type = int, default = 0,
89                     help = 'Random seed, < 0 is no seeding')
90
91 parser.add_argument('--nb_epochs',
92                     type = int, default = 100,
93                     help = 'How many epochs')
94
95 parser.add_argument('--batch_size',
96                     type = int, default = 25,
97                     help = 'Batch size')
98
99 parser.add_argument('--nb_samples',
100                     type = int, default = 25000,
101                     help = 'Number of training examples')
102
103 parser.add_argument('--learning_rate',
104                     type = float, default = 1e-3,
105                     help = 'Learning rate')
106
107 parser.add_argument('--ema_decay',
108                     type = float, default = 0.9999,
109                     help = 'EMA decay, <= 0 is no EMA')
110
111 data_list = ', '.join( [ str(k) for k in samplers ])
112
113 parser.add_argument('--data',
114                     type = str, default = 'gaussian_mixture',
115                     help = f'Toy data-set to use: {data_list}')
116
117 parser.add_argument('--no_window',
118                     action='store_true', default = False)
119
120 args = parser.parse_args()
121
122 if args.seed >= 0:
123     # torch.backends.cudnn.deterministic = True
124     # torch.backends.cudnn.benchmark = False
125     # torch.use_deterministic_algorithms(True)
126     torch.manual_seed(args.seed)
127     if torch.cuda.is_available():
128         torch.cuda.manual_seed_all(args.seed)
129
130 ######################################################################
131
132 class EMA:
133     def __init__(self, model, decay):
134         self.model = model
135         self.decay = decay
136         self.mem = { }
137         with torch.no_grad():
138             for p in model.parameters():
139                 self.mem[p] = p.clone()
140
141     def step(self):
142         with torch.no_grad():
143             for p in self.model.parameters():
144                 self.mem[p].copy_(self.decay * self.mem[p] + (1 - self.decay) * p)
145
146     def copy_to_model(self):
147         with torch.no_grad():
148             for p in self.model.parameters():
149                 p.copy_(self.mem[p])
150
151 ######################################################################
152
153 # Gets a pair (x, t) and appends t (scalar or 1d tensor) to x as an
154 # additional dimension / channel
155
156 class TimeAppender(nn.Module):
157     def __init__(self):
158         super().__init__()
159
160     def forward(self, u):
161         x, t = u
162         if not torch.is_tensor(t):
163             t = x.new_full((x.size(0),), t)
164         t = t.view((-1,) + (1,) * (x.dim() - 1)).expand_as(x[:,:1])
165         return torch.cat((x, t), 1)
166
167 class ConvNet(nn.Module):
168     def __init__(self, in_channels, out_channels):
169         super().__init__()
170
171         ks, nc = 5, 64
172
173         self.core = nn.Sequential(
174             TimeAppender(),
175             nn.Conv2d(in_channels + 1, nc, ks, padding = ks//2),
176             nn.ReLU(),
177             nn.Conv2d(nc, nc, ks, padding = ks//2),
178             nn.ReLU(),
179             nn.Conv2d(nc, nc, ks, padding = ks//2),
180             nn.ReLU(),
181             nn.Conv2d(nc, nc, ks, padding = ks//2),
182             nn.ReLU(),
183             nn.Conv2d(nc, nc, ks, padding = ks//2),
184             nn.ReLU(),
185             nn.Conv2d(nc, out_channels, ks, padding = ks//2),
186         )
187
188     def forward(self, u):
189         return self.core(u)
190
191 ######################################################################
192 # Data
193
194 try:
195     train_input = samplers[args.data](args.nb_samples).to(device)
196 except KeyError:
197     print(f'unknown data {args.data}')
198     exit(1)
199
200 train_mean, train_std = train_input.mean(), train_input.std()
201
202 ######################################################################
203 # Model
204
205 if train_input.dim() == 2:
206     nh = 256
207
208     model = nn.Sequential(
209         TimeAppender(),
210         nn.Linear(train_input.size(1) + 1, nh),
211         nn.ReLU(),
212         nn.Linear(nh, nh),
213         nn.ReLU(),
214         nn.Linear(nh, nh),
215         nn.ReLU(),
216         nn.Linear(nh, train_input.size(1)),
217     )
218
219 elif train_input.dim() == 4:
220
221     model = ConvNet(train_input.size(1), train_input.size(1))
222
223 model.to(device)
224
225 print(f'nb_parameters {sum([ p.numel() for p in model.parameters() ])}')
226
227 ######################################################################
228 # Generate
229
230 def generate(size, T, alpha, alpha_bar, sigma, model, train_mean, train_std):
231
232     with torch.no_grad():
233
234         x = torch.randn(size, device = device)
235
236         for t in range(T-1, -1, -1):
237             output = model((x, t / (T - 1) - 0.5))
238             z = torch.zeros_like(x) if t == 0 else torch.randn_like(x)
239             x = 1/torch.sqrt(alpha[t]) \
240                 * (x - (1-alpha[t]) / torch.sqrt(1-alpha_bar[t]) * output) \
241                 + sigma[t] * z
242
243         x = x * train_std + train_mean
244
245         return x
246
247 ######################################################################
248 # Train
249
250 T = 1000
251 beta = torch.linspace(1e-4, 0.02, T, device = device)
252 alpha = 1 - beta
253 alpha_bar = alpha.log().cumsum(0).exp()
254 sigma = beta.sqrt()
255
256 ema = EMA(model, decay = args.ema_decay) if args.ema_decay > 0 else None
257
258 for k in range(args.nb_epochs):
259
260     acc_loss = 0
261     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
262
263     for x0 in train_input.split(args.batch_size):
264         x0 = (x0 - train_mean) / train_std
265         t = torch.randint(T, (x0.size(0),) + (1,) * (x0.dim() - 1), device = x0.device)
266         eps = torch.randn_like(x0)
267         xt = torch.sqrt(alpha_bar[t]) * x0 + torch.sqrt(1 - alpha_bar[t]) * eps
268         output = model((xt, t / (T - 1) - 0.5))
269         loss = (eps - output).pow(2).mean()
270         acc_loss += loss.item() * x0.size(0)
271
272         optimizer.zero_grad()
273         loss.backward()
274         optimizer.step()
275
276         if ema is not None: ema.step()
277
278     print(f'{k} {acc_loss / train_input.size(0)}')
279
280 if ema is not None: ema.copy_to_model()
281
282 ######################################################################
283 # Plot
284
285 model.eval()
286
287 ########################################
288 # Nx1 -> histogram
289 if train_input.dim() == 2 and train_input.size(1) == 1:
290
291     fig = plt.figure()
292     fig.set_figheight(5)
293     fig.set_figwidth(8)
294
295     ax = fig.add_subplot(1, 1, 1)
296
297     x = generate((10000, 1), T, alpha, alpha_bar, sigma,
298                  model, train_mean, train_std)
299
300     ax.set_xlim(-1.25, 1.25)
301     ax.spines.right.set_visible(False)
302     ax.spines.top.set_visible(False)
303
304     d = train_input.flatten().detach().to('cpu').numpy()
305     ax.hist(d, 25, (-1, 1),
306             density = True,
307             histtype = 'bar', edgecolor = 'white', color = 'lightblue', label = 'Train')
308
309     d = x.flatten().detach().to('cpu').numpy()
310     ax.hist(d, 25, (-1, 1),
311             density = True,
312             histtype = 'step', color = 'red', label = 'Synthesis')
313
314     ax.legend(frameon = False, loc = 2)
315
316     filename = f'minidiffusion_{args.data}.pdf'
317     print(f'saving {filename}')
318     fig.savefig(filename, bbox_inches='tight')
319
320     if not args.no_window and hasattr(plt.get_current_fig_manager(), 'window'):
321         plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
322         plt.show()
323
324 ########################################
325 # Nx2 -> scatter plot
326 elif train_input.dim() == 2 and train_input.size(1) == 2:
327
328     fig = plt.figure()
329     fig.set_figheight(6)
330     fig.set_figwidth(6)
331
332     ax = fig.add_subplot(1, 1, 1)
333
334     x = generate((1000, 2), T, alpha, alpha_bar, sigma,
335                  model, train_mean, train_std)
336
337     ax.set_xlim(-1.5, 1.5)
338     ax.set_ylim(-1.5, 1.5)
339     ax.set(aspect = 1)
340     ax.spines.right.set_visible(False)
341     ax.spines.top.set_visible(False)
342
343     d = train_input[:x.size(0)].detach().to('cpu').numpy()
344     ax.scatter(d[:, 0], d[:, 1],
345                s = 2.5, color = 'gray', label = 'Train')
346
347     d = x.detach().to('cpu').numpy()
348     ax.scatter(d[:, 0], d[:, 1],
349                s = 2.0, color = 'red', label = 'Synthesis')
350
351     ax.legend(frameon = False, loc = 2)
352
353     filename = f'minidiffusion_{args.data}.pdf'
354     print(f'saving {filename}')
355     fig.savefig(filename, bbox_inches='tight')
356
357     if not args.no_window and hasattr(plt.get_current_fig_manager(), 'window'):
358         plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
359         plt.show()
360
361 ########################################
362 # NxCxHxW -> image
363 elif train_input.dim() == 4:
364
365     x = generate((128,) + train_input.size()[1:], T, alpha, alpha_bar, sigma,
366                  model, train_mean, train_std)
367
368     x = torchvision.utils.make_grid(x.clamp(min = 0, max = 255),
369                                     nrow = 16, padding = 1, pad_value = 64)
370     x = F.pad(x, pad = (2, 2, 2, 2), value = 64)[None]
371
372     t = torchvision.utils.make_grid(train_input[:128],
373                                     nrow = 16, padding = 1, pad_value = 64)
374     t = F.pad(t, pad = (2, 2, 2, 2), value = 64)[None]
375
376     result = 1 - torch.cat((t, x), 2) / 255
377
378     filename = f'minidiffusion_{args.data}.png'
379     print(f'saving {filename}')
380     torchvision.utils.save_image(result, filename)
381
382 else:
383
384     print(f'cannot plot result of size {train_input.size()}')
385
386 ######################################################################