Update.
[pytorch.git] / minidiffusion.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, argparse
9
10 import matplotlib.pyplot as plt
11
12 import torch, torchvision
13 from torch import nn
14 from torch.nn import functional as F
15
16 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
18 print(f"device {device}")
19
20 ######################################################################
21
22
23 def sample_gaussian_mixture(nb):
24     p, std = 0.3, 0.2
25     result = torch.randn(nb, 1) * std
26     result = result + torch.sign(torch.rand(result.size()) - p) / 2
27     return result
28
29
30 def sample_ramp(nb):
31     result = torch.min(torch.rand(nb, 1), torch.rand(nb, 1))
32     return result
33
34
35 def sample_two_discs(nb):
36     a = torch.rand(nb) * math.pi * 2
37     b = torch.rand(nb).sqrt()
38     q = (torch.rand(nb) <= 0.5).long()
39     b = b * (0.3 + 0.2 * q)
40     result = torch.empty(nb, 2)
41     result[:, 0] = a.cos() * b - 0.5 + q
42     result[:, 1] = a.sin() * b - 0.5 + q
43     return result
44
45
46 def sample_disc_grid(nb):
47     a = torch.rand(nb) * math.pi * 2
48     b = torch.rand(nb).sqrt()
49     N = 4
50     q = (torch.randint(N, (nb,)) - (N - 1) / 2) / ((N - 1) / 2)
51     r = (torch.randint(N, (nb,)) - (N - 1) / 2) / ((N - 1) / 2)
52     b = b * 0.1
53     result = torch.empty(nb, 2)
54     result[:, 0] = a.cos() * b + q
55     result[:, 1] = a.sin() * b + r
56     return result
57
58
59 def sample_spiral(nb):
60     u = torch.rand(nb)
61     rho = u * 0.65 + 0.25 + torch.rand(nb) * 0.15
62     theta = u * math.pi * 3
63     result = torch.empty(nb, 2)
64     result[:, 0] = theta.cos() * rho
65     result[:, 1] = theta.sin() * rho
66     return result
67
68
69 def sample_mnist(nb):
70     train_set = torchvision.datasets.MNIST(root="./data/", train=True, download=True)
71     result = train_set.data[:nb].to(device).view(-1, 1, 28, 28).float()
72     return result
73
74
75 samplers = {
76     f.__name__.removeprefix("sample_"): f
77     for f in [
78         sample_gaussian_mixture,
79         sample_ramp,
80         sample_two_discs,
81         sample_disc_grid,
82         sample_spiral,
83         sample_mnist,
84     ]
85 }
86
87 ######################################################################
88
89 parser = argparse.ArgumentParser(
90     description="""A minimal implementation of Jonathan Ho, Ajay Jain, Pieter Abbeel
91 "Denoising Diffusion Probabilistic Models" (2020)
92 https://arxiv.org/abs/2006.11239""",
93     formatter_class=argparse.ArgumentDefaultsHelpFormatter,
94 )
95
96 parser.add_argument(
97     "--seed", type=int, default=0, help="Random seed, < 0 is no seeding"
98 )
99
100 parser.add_argument("--nb_epochs", type=int, default=100, help="How many epochs")
101
102 parser.add_argument("--batch_size", type=int, default=25, help="Batch size")
103
104 parser.add_argument(
105     "--nb_samples", type=int, default=25000, help="Number of training examples"
106 )
107
108 parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate")
109
110 parser.add_argument(
111     "--ema_decay", type=float, default=0.9999, help="EMA decay, <= 0 is no EMA"
112 )
113
114 data_list = ", ".join([str(k) for k in samplers])
115
116 parser.add_argument(
117     "--data",
118     type=str,
119     default="gaussian_mixture",
120     help=f"Toy data-set to use: {data_list}",
121 )
122
123 parser.add_argument("--no_window", action="store_true", default=False)
124
125 args = parser.parse_args()
126
127 if args.seed >= 0:
128     # torch.backends.cudnn.deterministic = True
129     # torch.backends.cudnn.benchmark = False
130     # torch.use_deterministic_algorithms(True)
131     torch.manual_seed(args.seed)
132     if torch.cuda.is_available():
133         torch.cuda.manual_seed_all(args.seed)
134
135 ######################################################################
136
137
138 class EMA:
139     def __init__(self, model, decay):
140         self.model = model
141         self.decay = decay
142         self.mem = {}
143         with torch.no_grad():
144             for p in model.parameters():
145                 self.mem[p] = p.clone()
146
147     def step(self):
148         with torch.no_grad():
149             for p in self.model.parameters():
150                 self.mem[p].copy_(self.decay * self.mem[p] + (1 - self.decay) * p)
151
152     def copy_to_model(self):
153         with torch.no_grad():
154             for p in self.model.parameters():
155                 p.copy_(self.mem[p])
156
157
158 ######################################################################
159
160 # Gets a pair (x, t) and appends t (scalar or 1d tensor) to x as an
161 # additional dimension / channel
162
163
164 class TimeAppender(nn.Module):
165     def __init__(self):
166         super().__init__()
167
168     def forward(self, u):
169         x, t = u
170         if not torch.is_tensor(t):
171             t = x.new_full((x.size(0),), t)
172         t = t.view((-1,) + (1,) * (x.dim() - 1)).expand_as(x[:, :1])
173         return torch.cat((x, t), 1)
174
175
176 class ConvNet(nn.Module):
177     def __init__(self, in_channels, out_channels):
178         super().__init__()
179
180         ks, nc = 5, 64
181
182         self.core = nn.Sequential(
183             TimeAppender(),
184             nn.Conv2d(in_channels + 1, nc, ks, padding=ks // 2),
185             nn.ReLU(),
186             nn.Conv2d(nc, nc, ks, padding=ks // 2),
187             nn.ReLU(),
188             nn.Conv2d(nc, nc, ks, padding=ks // 2),
189             nn.ReLU(),
190             nn.Conv2d(nc, nc, ks, padding=ks // 2),
191             nn.ReLU(),
192             nn.Conv2d(nc, nc, ks, padding=ks // 2),
193             nn.ReLU(),
194             nn.Conv2d(nc, out_channels, ks, padding=ks // 2),
195         )
196
197     def forward(self, u):
198         return self.core(u)
199
200
201 ######################################################################
202 # Data
203
204 try:
205     train_input = samplers[args.data](args.nb_samples).to(device)
206 except KeyError:
207     print(f"unknown data {args.data}")
208     exit(1)
209
210 train_mean, train_std = train_input.mean(), train_input.std()
211
212 ######################################################################
213 # Model
214
215 if train_input.dim() == 2:
216     nh = 256
217
218     model = nn.Sequential(
219         TimeAppender(),
220         nn.Linear(train_input.size(1) + 1, nh),
221         nn.ReLU(),
222         nn.Linear(nh, nh),
223         nn.ReLU(),
224         nn.Linear(nh, nh),
225         nn.ReLU(),
226         nn.Linear(nh, train_input.size(1)),
227     )
228
229 elif train_input.dim() == 4:
230     model = ConvNet(train_input.size(1), train_input.size(1))
231
232 model.to(device)
233
234 print(f"nb_parameters {sum([ p.numel() for p in model.parameters() ])}")
235
236 ######################################################################
237 # Generate
238
239
240 def generate(size, T, alpha, alpha_bar, sigma, model, train_mean, train_std):
241     with torch.no_grad():
242         x = torch.randn(size, device=device)
243
244         for t in range(T - 1, -1, -1):
245             output = model((x, t / (T - 1) - 0.5))
246             z = torch.zeros_like(x) if t == 0 else torch.randn_like(x)
247             x = (
248                 1
249                 / torch.sqrt(alpha[t])
250                 * (x - (1 - alpha[t]) / torch.sqrt(1 - alpha_bar[t]) * output)
251                 + sigma[t] * z
252             )
253
254         x = x * train_std + train_mean
255
256         return x
257
258
259 ######################################################################
260 # Train
261
262 T = 1000
263 beta = torch.linspace(1e-4, 0.02, T, device=device)
264 alpha = 1 - beta
265 alpha_bar = alpha.log().cumsum(0).exp()
266 sigma = beta.sqrt()
267
268 ema = EMA(model, decay=args.ema_decay) if args.ema_decay > 0 else None
269
270 for k in range(args.nb_epochs):
271     acc_loss = 0
272     optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
273
274     for x0 in train_input.split(args.batch_size):
275         x0 = (x0 - train_mean) / train_std
276         t = torch.randint(T, (x0.size(0),) + (1,) * (x0.dim() - 1), device=x0.device)
277         eps = torch.randn_like(x0)
278         xt = torch.sqrt(alpha_bar[t]) * x0 + torch.sqrt(1 - alpha_bar[t]) * eps
279         output = model((xt, t / (T - 1) - 0.5))
280         loss = (eps - output).pow(2).mean()
281         acc_loss += loss.item() * x0.size(0)
282
283         optimizer.zero_grad()
284         loss.backward()
285         optimizer.step()
286
287         if ema is not None:
288             ema.step()
289
290     print(f"{k} {acc_loss / train_input.size(0)}")
291
292 if ema is not None:
293     ema.copy_to_model()
294
295 ######################################################################
296 # Plot
297
298 model.eval()
299
300 ########################################
301 # Nx1 -> histogram
302 if train_input.dim() == 2 and train_input.size(1) == 1:
303     fig = plt.figure()
304     fig.set_figheight(5)
305     fig.set_figwidth(8)
306
307     ax = fig.add_subplot(1, 1, 1)
308
309     x = generate((10000, 1), T, alpha, alpha_bar, sigma, model, train_mean, train_std)
310
311     ax.set_xlim(-1.25, 1.25)
312     ax.spines.right.set_visible(False)
313     ax.spines.top.set_visible(False)
314
315     d = train_input.flatten().detach().to("cpu").numpy()
316     ax.hist(
317         d,
318         25,
319         (-1, 1),
320         density=True,
321         histtype="bar",
322         edgecolor="white",
323         color="lightblue",
324         label="Train",
325     )
326
327     d = x.flatten().detach().to("cpu").numpy()
328     ax.hist(
329         d, 25, (-1, 1), density=True, histtype="step", color="red", label="Synthesis"
330     )
331
332     ax.legend(frameon=False, loc=2)
333
334     filename = f"minidiffusion_{args.data}.pdf"
335     print(f"saving {filename}")
336     fig.savefig(filename, bbox_inches="tight")
337
338     if not args.no_window and hasattr(plt.get_current_fig_manager(), "window"):
339         plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
340         plt.show()
341
342 ########################################
343 # Nx2 -> scatter plot
344 elif train_input.dim() == 2 and train_input.size(1) == 2:
345     fig = plt.figure()
346     fig.set_figheight(6)
347     fig.set_figwidth(6)
348
349     ax = fig.add_subplot(1, 1, 1)
350
351     x = generate((1000, 2), T, alpha, alpha_bar, sigma, model, train_mean, train_std)
352
353     ax.set_xlim(-1.5, 1.5)
354     ax.set_ylim(-1.5, 1.5)
355     ax.set(aspect=1)
356     ax.spines.right.set_visible(False)
357     ax.spines.top.set_visible(False)
358
359     d = train_input[: x.size(0)].detach().to("cpu").numpy()
360     ax.scatter(d[:, 0], d[:, 1], s=2.5, color="gray", label="Train")
361
362     d = x.detach().to("cpu").numpy()
363     ax.scatter(d[:, 0], d[:, 1], s=2.0, color="red", label="Synthesis")
364
365     ax.legend(frameon=False, loc=2)
366
367     filename = f"minidiffusion_{args.data}.pdf"
368     print(f"saving {filename}")
369     fig.savefig(filename, bbox_inches="tight")
370
371     if not args.no_window and hasattr(plt.get_current_fig_manager(), "window"):
372         plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
373         plt.show()
374
375 ########################################
376 # NxCxHxW -> image
377 elif train_input.dim() == 4:
378     x = generate(
379         (128,) + train_input.size()[1:],
380         T,
381         alpha,
382         alpha_bar,
383         sigma,
384         model,
385         train_mean,
386         train_std,
387     )
388
389     x = torchvision.utils.make_grid(
390         x.clamp(min=0, max=255), nrow=16, padding=1, pad_value=64
391     )
392     x = F.pad(x, pad=(2, 2, 2, 2), value=64)[None]
393
394     t = torchvision.utils.make_grid(train_input[:128], nrow=16, padding=1, pad_value=64)
395     t = F.pad(t, pad=(2, 2, 2, 2), value=64)[None]
396
397     result = 1 - torch.cat((t, x), 2) / 255
398
399     filename = f"minidiffusion_{args.data}.png"
400     print(f"saving {filename}")
401     torchvision.utils.save_image(result, filename)
402
403 else:
404     print(f"cannot plot result of size {train_input.size()}")
405
406 ######################################################################