Update.
[pytorch.git] / minidiffusion.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, argparse
9
10 import matplotlib.pyplot as plt
11
12 import torch, torchvision
13 from torch import nn
14
15 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
17 ######################################################################
18
19 def sample_gaussian_mixture(nb):
20     p, std = 0.3, 0.2
21     result = torch.empty(nb, 1).normal_(0, std)
22     result = result + torch.sign(torch.rand(result.size()) - p) / 2
23     return result
24
25 def sample_two_discs(nb):
26     a = torch.rand(nb) * math.pi * 2
27     b = torch.rand(nb).sqrt()
28     q = (torch.rand(nb) <= 0.5).long()
29     b = b * (0.3 + 0.2 * q)
30     result = torch.empty(nb, 2)
31     result[:, 0] = a.cos() * b - 0.5 + q
32     result[:, 1] = a.sin() * b - 0.5 + q
33     return result
34
35 def sample_disc_grid(nb):
36     a = torch.rand(nb) * math.pi * 2
37     b = torch.rand(nb).sqrt()
38     q = torch.randint(5, (nb,)) / 2.5 - 2 / 2.5
39     r = torch.randint(5, (nb,)) / 2.5 - 2 / 2.5
40     b = b * 0.1
41     result = torch.empty(nb, 2)
42     result[:, 0] = a.cos() * b + q
43     result[:, 1] = a.sin() * b + r
44     return result
45
46 def sample_spiral(nb):
47     u = torch.rand(nb)
48     rho = u * 0.65 + 0.25 + torch.rand(nb) * 0.15
49     theta = u * math.pi * 3
50     result = torch.empty(nb, 2)
51     result[:, 0] = theta.cos() * rho
52     result[:, 1] = theta.sin() * rho
53     return result
54
55 def sample_mnist(nb):
56     train_set = torchvision.datasets.MNIST(root = './data/', train = True, download = True)
57     result = train_set.data[:nb].to(device).view(-1, 1, 28, 28).float()
58     return result
59
60 samplers = {
61     'gaussian_mixture': sample_gaussian_mixture,
62     'two_discs': sample_two_discs,
63     'disc_grid': sample_disc_grid,
64     'spiral': sample_spiral,
65     'mnist': sample_mnist,
66 }
67
68 ######################################################################
69
70 parser = argparse.ArgumentParser(
71     description = '''A minimal implementation of Jonathan Ho, Ajay Jain, Pieter Abbeel
72 "Denoising Diffusion Probabilistic Models" (2020)
73 https://arxiv.org/abs/2006.11239''',
74
75     formatter_class = argparse.ArgumentDefaultsHelpFormatter
76 )
77
78 parser.add_argument('--seed',
79                     type = int, default = 0,
80                     help = 'Random seed, < 0 is no seeding')
81
82 parser.add_argument('--nb_epochs',
83                     type = int, default = 100,
84                     help = 'How many epochs')
85
86 parser.add_argument('--batch_size',
87                     type = int, default = 25,
88                     help = 'Batch size')
89
90 parser.add_argument('--nb_samples',
91                     type = int, default = 25000,
92                     help = 'Number of training examples')
93
94 parser.add_argument('--learning_rate',
95                     type = float, default = 1e-3,
96                     help = 'Learning rate')
97
98 parser.add_argument('--ema_decay',
99                     type = float, default = 0.9999,
100                     help = 'EMA decay, < 0 is no EMA')
101
102 data_list = ', '.join( [ str(k) for k in samplers ])
103
104 parser.add_argument('--data',
105                     type = str, default = 'gaussian_mixture',
106                     help = f'Toy data-set to use: {data_list}')
107
108 args = parser.parse_args()
109
110 if args.seed >= 0:
111     # torch.backends.cudnn.deterministic = True
112     # torch.backends.cudnn.benchmark = False
113     # torch.use_deterministic_algorithms(True)
114     torch.manual_seed(args.seed)
115     if torch.cuda.is_available():
116         torch.cuda.manual_seed_all(args.seed)
117
118 ######################################################################
119
120 class EMA:
121     def __init__(self, model, decay):
122         self.model = model
123         self.decay = decay
124         if self.decay < 0: return
125         self.ema = { }
126         with torch.no_grad():
127             for p in model.parameters():
128                 self.ema[p] = p.clone()
129
130     def step(self):
131         if self.decay < 0: return
132         with torch.no_grad():
133             for p in self.model.parameters():
134                 self.ema[p].copy_(self.decay * self.ema[p] + (1 - self.decay) * p)
135
136     def copy(self):
137         if self.decay < 0: return
138         with torch.no_grad():
139             for p in self.model.parameters():
140                 p.copy_(self.ema[p])
141
142 ######################################################################
143
144 class ConvNet(nn.Module):
145     def __init__(self, in_channels, out_channels):
146         super().__init__()
147
148         ks, nc = 5, 64
149
150         self.core = nn.Sequential(
151             nn.Conv2d(in_channels, nc, ks, padding = ks//2),
152             nn.ReLU(),
153             nn.Conv2d(nc, nc, ks, padding = ks//2),
154             nn.ReLU(),
155             nn.Conv2d(nc, nc, ks, padding = ks//2),
156             nn.ReLU(),
157             nn.Conv2d(nc, nc, ks, padding = ks//2),
158             nn.ReLU(),
159             nn.Conv2d(nc, nc, ks, padding = ks//2),
160             nn.ReLU(),
161             nn.Conv2d(nc, out_channels, ks, padding = ks//2),
162         )
163
164     def forward(self, x):
165         return self.core(x)
166
167 ######################################################################
168 # Data
169
170 try:
171     train_input = samplers[args.data](args.nb_samples).to(device)
172 except KeyError:
173     print(f'unknown data {args.data}')
174     exit(1)
175
176 train_mean, train_std = train_input.mean(), train_input.std()
177
178 ######################################################################
179 # Model
180
181 if train_input.dim() == 2:
182     nh = 64
183
184     model = nn.Sequential(
185         nn.Linear(train_input.size(1) + 1, nh),
186         nn.ReLU(),
187         nn.Linear(nh, nh),
188         nn.ReLU(),
189         nn.Linear(nh, nh),
190         nn.ReLU(),
191         nn.Linear(nh, train_input.size(1)),
192     )
193
194 elif train_input.dim() == 4:
195
196     model = ConvNet(train_input.size(1) + 1, train_input.size(1))
197
198 model.to(device)
199
200 ######################################################################
201 # Train
202
203 T = 1000
204 beta = torch.linspace(1e-4, 0.02, T, device = device)
205 alpha = 1 - beta
206 alpha_bar = alpha.log().cumsum(0).exp()
207 sigma = beta.sqrt()
208
209 ema = EMA(model, decay = args.ema_decay)
210
211 for k in range(args.nb_epochs):
212
213     acc_loss = 0
214     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
215
216     for x0 in train_input.split(args.batch_size):
217         x0 = (x0 - train_mean) / train_std
218         t = torch.randint(T, (x0.size(0),) + (1,) * (x0.dim() - 1), device = x0.device)
219         eps = torch.randn_like(x0)
220         input = torch.sqrt(alpha_bar[t]) * x0 + torch.sqrt(1 - alpha_bar[t]) * eps
221         input = torch.cat((input, t.expand_as(x0[:,:1]) / (T - 1) - 0.5), 1)
222         loss = (eps - model(input)).pow(2).mean()
223         acc_loss += loss.item() * x0.size(0)
224
225         optimizer.zero_grad()
226         loss.backward()
227         optimizer.step()
228
229         ema.step()
230
231     if k%10 == 0: print(f'{k} {acc_loss / train_input.size(0)}')
232
233 ema.copy()
234
235 ######################################################################
236 # Generate
237
238 def generate(size, model):
239     with torch.no_grad():
240         x = torch.randn(size, device = device)
241
242         for t in range(T-1, -1, -1):
243             z = torch.zeros_like(x) if t == 0 else torch.randn_like(x)
244             input = torch.cat((x, torch.full_like(x[:,:1], t / (T - 1) - 0.5)), 1)
245             x = 1/torch.sqrt(alpha[t]) \
246                 * (x - (1-alpha[t]) / torch.sqrt(1-alpha_bar[t]) * model(input)) \
247                 + sigma[t] * z
248
249         x = x * train_std + train_mean
250
251         return x
252
253 ######################################################################
254 # Plot
255
256 model.eval()
257
258 if train_input.dim() == 2:
259     fig = plt.figure()
260     ax = fig.add_subplot(1, 1, 1)
261
262     if train_input.size(1) == 1:
263
264         x = generate((10000, 1), model)
265
266         ax.set_xlim(-1.25, 1.25)
267
268         d = train_input.flatten().detach().to('cpu').numpy()
269         ax.hist(d, 25, (-1, 1),
270                 density = True,
271                 histtype = 'stepfilled', color = 'lightblue', label = 'Train')
272
273         d = x.flatten().detach().to('cpu').numpy()
274         ax.hist(d, 25, (-1, 1),
275                 density = True,
276                 histtype = 'step', color = 'red', label = 'Synthesis')
277
278         ax.legend(frameon = False, loc = 2)
279
280     elif train_input.size(1) == 2:
281
282         x = generate((1000, 2), model)
283
284         ax.set_xlim(-1.25, 1.25)
285         ax.set_ylim(-1.25, 1.25)
286         ax.set(aspect = 1)
287
288         d = train_input[:x.size(0)].detach().to('cpu').numpy()
289         ax.scatter(d[:, 0], d[:, 1],
290                    color = 'lightblue', label = 'Train')
291
292         d = x.detach().to('cpu').numpy()
293         ax.scatter(d[:, 0], d[:, 1],
294                    facecolors = 'none', color = 'red', label = 'Synthesis')
295
296         ax.legend(frameon = False, loc = 2)
297
298     filename = f'diffusion_{args.data}.pdf'
299     print(f'saving {filename}')
300     fig.savefig(filename, bbox_inches='tight')
301
302     if hasattr(plt.get_current_fig_manager(), 'window'):
303         plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
304         plt.show()
305
306 elif train_input.dim() == 4:
307     x = generate((128,) + train_input.size()[1:], model)
308     x = 1 - x.clamp(min = 0, max = 255) / 255
309     torchvision.utils.save_image(x, f'diffusion_{args.data}.png', nrow = 16, pad_value = 0.8)
310
311 ######################################################################