Tried to make the source clearer, added the TimeAppender Module.
[pytorch.git] / minidiffusion.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, argparse
9
10 import matplotlib.pyplot as plt
11
12 import torch, torchvision
13 from torch import nn
14
15 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
17 print(f'device {device}')
18
19 ######################################################################
20
21 def sample_gaussian_mixture(nb):
22     p, std = 0.3, 0.2
23     result = torch.randn(nb, 1) * std
24     result = result + torch.sign(torch.rand(result.size()) - p) / 2
25     return result
26
27 def sample_ramp(nb):
28     result = torch.min(torch.rand(nb, 1), torch.rand(nb, 1))
29     return result
30
31 def sample_two_discs(nb):
32     a = torch.rand(nb) * math.pi * 2
33     b = torch.rand(nb).sqrt()
34     q = (torch.rand(nb) <= 0.5).long()
35     b = b * (0.3 + 0.2 * q)
36     result = torch.empty(nb, 2)
37     result[:, 0] = a.cos() * b - 0.5 + q
38     result[:, 1] = a.sin() * b - 0.5 + q
39     return result
40
41 def sample_disc_grid(nb):
42     a = torch.rand(nb) * math.pi * 2
43     b = torch.rand(nb).sqrt()
44     N = 4
45     q = (torch.randint(N, (nb,)) - (N - 1) / 2) / ((N - 1) / 2)
46     r = (torch.randint(N, (nb,)) - (N - 1) / 2) / ((N - 1) / 2)
47     b = b * 0.1
48     result = torch.empty(nb, 2)
49     result[:, 0] = a.cos() * b + q
50     result[:, 1] = a.sin() * b + r
51     return result
52
53 def sample_spiral(nb):
54     u = torch.rand(nb)
55     rho = u * 0.65 + 0.25 + torch.rand(nb) * 0.15
56     theta = u * math.pi * 3
57     result = torch.empty(nb, 2)
58     result[:, 0] = theta.cos() * rho
59     result[:, 1] = theta.sin() * rho
60     return result
61
62 def sample_mnist(nb):
63     train_set = torchvision.datasets.MNIST(root = './data/', train = True, download = True)
64     result = train_set.data[:nb].to(device).view(-1, 1, 28, 28).float()
65     return result
66
67 samplers = {
68     'gaussian_mixture': sample_gaussian_mixture,
69     'ramp': sample_ramp,
70     'two_discs': sample_two_discs,
71     'disc_grid': sample_disc_grid,
72     'spiral': sample_spiral,
73     'mnist': sample_mnist,
74 }
75
76 ######################################################################
77
78 parser = argparse.ArgumentParser(
79     description = '''A minimal implementation of Jonathan Ho, Ajay Jain, Pieter Abbeel
80 "Denoising Diffusion Probabilistic Models" (2020)
81 https://arxiv.org/abs/2006.11239''',
82
83     formatter_class = argparse.ArgumentDefaultsHelpFormatter
84 )
85
86 parser.add_argument('--seed',
87                     type = int, default = 0,
88                     help = 'Random seed, < 0 is no seeding')
89
90 parser.add_argument('--nb_epochs',
91                     type = int, default = 100,
92                     help = 'How many epochs')
93
94 parser.add_argument('--batch_size',
95                     type = int, default = 25,
96                     help = 'Batch size')
97
98 parser.add_argument('--nb_samples',
99                     type = int, default = 25000,
100                     help = 'Number of training examples')
101
102 parser.add_argument('--learning_rate',
103                     type = float, default = 1e-3,
104                     help = 'Learning rate')
105
106 parser.add_argument('--ema_decay',
107                     type = float, default = 0.9999,
108                     help = 'EMA decay, <= 0 is no EMA')
109
110 data_list = ', '.join( [ str(k) for k in samplers ])
111
112 parser.add_argument('--data',
113                     type = str, default = 'gaussian_mixture',
114                     help = f'Toy data-set to use: {data_list}')
115
116 args = parser.parse_args()
117
118 if args.seed >= 0:
119     # torch.backends.cudnn.deterministic = True
120     # torch.backends.cudnn.benchmark = False
121     # torch.use_deterministic_algorithms(True)
122     torch.manual_seed(args.seed)
123     if torch.cuda.is_available():
124         torch.cuda.manual_seed_all(args.seed)
125
126 ######################################################################
127
128 class EMA:
129     def __init__(self, model, decay):
130         self.model = model
131         self.decay = decay
132         self.mem = { }
133         with torch.no_grad():
134             for p in model.parameters():
135                 self.mem[p] = p.clone()
136
137     def step(self):
138         with torch.no_grad():
139             for p in self.model.parameters():
140                 self.mem[p].copy_(self.decay * self.mem[p] + (1 - self.decay) * p)
141
142     def copy_to_model(self):
143         with torch.no_grad():
144             for p in self.model.parameters():
145                 p.copy_(self.mem[p])
146
147 ######################################################################
148
149 # Gets a pair (x, t) and appends t (scalar or 1d tensor) to x as an
150 # additional dimension / channel
151
152 class TimeAppender(nn.Module):
153     def __init__(self):
154         super().__init__()
155
156     def forward(self, u):
157         x, t = u
158         if not torch.is_tensor(t):
159             t = x.new_full((x.size(0),), t)
160         t = t.view((-1,) + (1,) * (x.dim() - 1)).expand_as(x[:,:1])
161         return torch.cat((x, t), 1)
162
163 class ConvNet(nn.Module):
164     def __init__(self, in_channels, out_channels):
165         super().__init__()
166
167         ks, nc = 5, 64
168
169         self.core = nn.Sequential(
170             TimeAppender(),
171             nn.Conv2d(in_channels + 1, nc, ks, padding = ks//2),
172             nn.ReLU(),
173             nn.Conv2d(nc, nc, ks, padding = ks//2),
174             nn.ReLU(),
175             nn.Conv2d(nc, nc, ks, padding = ks//2),
176             nn.ReLU(),
177             nn.Conv2d(nc, nc, ks, padding = ks//2),
178             nn.ReLU(),
179             nn.Conv2d(nc, nc, ks, padding = ks//2),
180             nn.ReLU(),
181             nn.Conv2d(nc, out_channels, ks, padding = ks//2),
182         )
183
184     def forward(self, u):
185         return self.core(u)
186
187 ######################################################################
188 # Data
189
190 try:
191     train_input = samplers[args.data](args.nb_samples).to(device)
192 except KeyError:
193     print(f'unknown data {args.data}')
194     exit(1)
195
196 train_mean, train_std = train_input.mean(), train_input.std()
197
198 ######################################################################
199 # Model
200
201 if train_input.dim() == 2:
202     nh = 256
203
204     model = nn.Sequential(
205         TimeAppender(),
206         nn.Linear(train_input.size(1) + 1, nh),
207         nn.ReLU(),
208         nn.Linear(nh, nh),
209         nn.ReLU(),
210         nn.Linear(nh, nh),
211         nn.ReLU(),
212         nn.Linear(nh, train_input.size(1)),
213     )
214
215 elif train_input.dim() == 4:
216
217     model = ConvNet(train_input.size(1), train_input.size(1))
218
219 model.to(device)
220
221 print(f'nb_parameters {sum([ p.numel() for p in model.parameters() ])}')
222
223 ######################################################################
224 # Generate
225
226 def generate(size, alpha, alpha_bar, sigma, model, train_mean, train_std):
227
228     with torch.no_grad():
229
230         x = torch.randn(size, device = device)
231
232         for t in range(T-1, -1, -1):
233             output = model((x, t / (T - 1) - 0.5))
234             z = torch.zeros_like(x) if t == 0 else torch.randn_like(x)
235             x = 1/torch.sqrt(alpha[t]) \
236                 * (x - (1-alpha[t]) / torch.sqrt(1-alpha_bar[t]) * output) \
237                 + sigma[t] * z
238
239         x = x * train_std + train_mean
240
241         return x
242
243 ######################################################################
244 # Train
245
246 T = 1000
247 beta = torch.linspace(1e-4, 0.02, T, device = device)
248 alpha = 1 - beta
249 alpha_bar = alpha.log().cumsum(0).exp()
250 sigma = beta.sqrt()
251
252 ema = EMA(model, decay = args.ema_decay) if args.ema_decay > 0 else None
253
254 for k in range(args.nb_epochs):
255
256     acc_loss = 0
257     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
258
259     for x0 in train_input.split(args.batch_size):
260         x0 = (x0 - train_mean) / train_std
261         t = torch.randint(T, (x0.size(0),) + (1,) * (x0.dim() - 1), device = x0.device)
262         eps = torch.randn_like(x0)
263         xt = torch.sqrt(alpha_bar[t]) * x0 + torch.sqrt(1 - alpha_bar[t]) * eps
264         output = model((xt, t / (T - 1) - 0.5))
265         loss = (eps - output).pow(2).mean()
266         acc_loss += loss.item() * x0.size(0)
267
268         optimizer.zero_grad()
269         loss.backward()
270         optimizer.step()
271
272         if ema is not None: ema.step()
273
274     print(f'{k} {acc_loss / train_input.size(0)}')
275
276 if ema is not None: ema.copy_to_model()
277
278 ######################################################################
279 # Plot
280
281 model.eval()
282
283 if train_input.dim() == 2:
284
285     fig = plt.figure()
286     ax = fig.add_subplot(1, 1, 1)
287
288     # Nx1 -> histogram
289     if train_input.size(1) == 1:
290
291         x = generate((10000, 1), alpha, alpha_bar, sigma,
292                      model, train_mean, train_std)
293
294         ax.set_xlim(-1.25, 1.25)
295         ax.spines.right.set_visible(False)
296         ax.spines.top.set_visible(False)
297
298         d = train_input.flatten().detach().to('cpu').numpy()
299         ax.hist(d, 25, (-1, 1),
300                 density = True,
301                 histtype = 'stepfilled', color = 'lightblue', label = 'Train')
302
303         d = x.flatten().detach().to('cpu').numpy()
304         ax.hist(d, 25, (-1, 1),
305                 density = True,
306                 histtype = 'step', color = 'red', label = 'Synthesis')
307
308         ax.legend(frameon = False, loc = 2)
309
310     # Nx2 -> scatter plot
311     elif train_input.size(1) == 2:
312
313         x = generate((1000, 2), alpha, alpha_bar, sigma,
314                      model, train_mean, train_std)
315
316         ax.set_xlim(-1.5, 1.5)
317         ax.set_ylim(-1.5, 1.5)
318         ax.set(aspect = 1)
319         ax.spines.right.set_visible(False)
320         ax.spines.top.set_visible(False)
321
322         d = x.detach().to('cpu').numpy()
323         ax.scatter(d[:, 0], d[:, 1],
324                    s = 2.0, color = 'red', label = 'Synthesis')
325
326         d = train_input[:x.size(0)].detach().to('cpu').numpy()
327         ax.scatter(d[:, 0], d[:, 1],
328                    s = 2.0, color = 'gray', label = 'Train')
329
330         ax.legend(frameon = False, loc = 2)
331
332     filename = f'diffusion_{args.data}.pdf'
333     print(f'saving {filename}')
334     fig.savefig(filename, bbox_inches='tight')
335
336     if hasattr(plt.get_current_fig_manager(), 'window'):
337         plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
338         plt.show()
339
340 # NxCxHxW -> image
341 elif train_input.dim() == 4:
342
343     x = generate((128,) + train_input.size()[1:], alpha, alpha_bar, sigma,
344                  model, train_mean, train_std)
345     x = 1 - x.clamp(min = 0, max = 255) / 255
346
347     filename = f'diffusion_{args.data}.png'
348     print(f'saving {filename}')
349     torchvision.utils.save_image(x, filename, nrow = 16, pad_value = 0.8)
350
351 ######################################################################