3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
10 import matplotlib.pyplot as plt
12 import torch, torchvision
14 from torch.nn import functional as F
# Select the CUDA device when torch reports that one is available, otherwise
# use the CPU, then print the choice.
16 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18 print(f'device {device}')
20 ######################################################################
# Draws nb scalar samples, built as an (nb, 1) tensor: Gaussian noise scaled
# by `std`, then shifted by half the sign of (uniform draw - p), i.e. by
# +1/2 or -1/2.
# NOTE(review): the definitions of `p` and `std` and the return statement
# are not visible in this excerpt -- confirm them against the full file.
22 def sample_gaussian_mixture(nb):
24 result = torch.randn(nb, 1) * std
25 result = result + torch.sign(torch.rand(result.size()) - p) / 2
# Element-wise minimum of two independent (nb, 1) uniform draws.
# NOTE(review): the `def` line that owns this statement and its return are
# not visible in this excerpt -- confirm which sampler this belongs to.
29 result = torch.min(torch.rand(nb, 1), torch.rand(nb, 1))
# Draws nb 2d points, built as an (nb, 2) tensor, from two discs.
# NOTE(review): the return statement and the `math` import are not visible
# in this excerpt.
32 def sample_two_discs(nb):
# Angle in [0, 2*pi) and square root of a uniform draw used as the radius.
33 a = torch.rand(nb) * math.pi * 2
34 b = torch.rand(nb).sqrt()
# q is a 0/1 selector; it picks the disc each point belongs to.
35 q = (torch.rand(nb) <= 0.5).long()
# Radius scale is 0.3 when q == 0 and 0.5 when q == 1.
36 b = b * (0.3 + 0.2 * q)
37 result = torch.empty(nb, 2)
# Both coordinates are offset by -0.5 + q, so the two centres are
# (-0.5, -0.5) and (0.5, 0.5).
38 result[:, 0] = a.cos() * b - 0.5 + q
39 result[:, 1] = a.sin() * b - 0.5 + q
# Draws nb 2d points, built as an (nb, 2) tensor, as polar offsets around
# centres placed on an N x N grid.
# NOTE(review): the definition of `N`, any rescaling of the radius `b`, and
# the return statement are not visible in this excerpt -- confirm.
42 def sample_disc_grid(nb):
43 a = torch.rand(nb) * math.pi * 2
44 b = torch.rand(nb).sqrt()
# Integer draws in {0, ..., N-1} are centred and divided by (N - 1) / 2,
# which maps them to N evenly spaced values between -1 and 1.
46 q = (torch.randint(N, (nb,)) - (N - 1) / 2) / ((N - 1) / 2)
47 r = (torch.randint(N, (nb,)) - (N - 1) / 2) / ((N - 1) / 2)
49 result = torch.empty(nb, 2)
# q is the x centre and r the y centre of each point.
50 result[:, 0] = a.cos() * b + q
51 result[:, 1] = a.sin() * b + r
# Draws nb 2d points, built as an (nb, 2) tensor, in polar form where both
# the radius and the angle grow with the same variable `u`.
# NOTE(review): the definition of `u` and the return statement are not
# visible in this excerpt -- presumably u is a uniform draw of size nb.
54 def sample_spiral(nb):
# Radius: affine in u, plus a uniform jitter scaled by 0.15.
56 rho = u * 0.65 + 0.25 + torch.rand(nb) * 0.15
# Angle: u scaled to 3 * pi.
57 theta = u * math.pi * 3
58 result = torch.empty(nb, 2)
59 result[:, 0] = theta.cos() * rho
60 result[:, 1] = theta.sin() * rho
# Opens the torchvision MNIST training split rooted at ./data/ (with
# download = True), keeps the first nb images, moves them to `device` and
# views them as a (-1, 1, 28, 28) float tensor.
# NOTE(review): the `def` line that owns these statements and the return
# are not visible in this excerpt.
64 train_set = torchvision.datasets.MNIST(root = './data/', train = True, download = True)
65 result = train_set.data[:nb].to(device).view(-1, 1, 28, 28).float()
# Dict-comprehension entry: maps each sampler function's name, with the
# 'sample_' prefix removed, to the function itself.
# NOTE(review): the opening of the literal, the remaining list entries and
# the closing brackets are not visible in this excerpt; presumably this
# builds the `samplers` mapping used further down -- confirm.
69 f.__name__.removeprefix('sample_') : f for f in [
70 sample_gaussian_mixture,
79 ######################################################################
# Command-line interface: random seed, number of epochs, batch size, number
# of training samples, learning rate, EMA decay, data-set name and a
# --no_window flag. The parsed values end up in `args`.
# NOTE(review): the closing parenthesis of ArgumentParser(...), the end of
# the --batch_size call and the `argparse` import are not visible in this
# excerpt.
81 parser = argparse.ArgumentParser(
82 description = '''A minimal implementation of Jonathan Ho, Ajay Jain, Pieter Abbeel
83 "Denoising Diffusion Probabilistic Models" (2020)
84 https://arxiv.org/abs/2006.11239''',
86 formatter_class = argparse.ArgumentDefaultsHelpFormatter
89 parser.add_argument('--seed',
90 type = int, default = 0,
91 help = 'Random seed, < 0 is no seeding')
93 parser.add_argument('--nb_epochs',
94 type = int, default = 100,
95 help = 'How many epochs')
97 parser.add_argument('--batch_size',
98 type = int, default = 25,
101 parser.add_argument('--nb_samples',
102 type = int, default = 25000,
103 help = 'Number of training examples')
105 parser.add_argument('--learning_rate',
106 type = float, default = 1e-3,
107 help = 'Learning rate')
109 parser.add_argument('--ema_decay',
110 type = float, default = 0.9999,
111 help = 'EMA decay, <= 0 is no EMA')
# Comma-separated list of the sampler names, shown in the --data help text.
113 data_list = ', '.join( [ str(k) for k in samplers ])
115 parser.add_argument('--data',
116 type = str, default = 'gaussian_mixture',
117 help = f'Toy data-set to use: {data_list}')
119 parser.add_argument('--no_window',
120 action='store_true', default = False)
122 args = parser.parse_args()
# Seed torch's generator with --seed and, when CUDA is available, all CUDA
# generators too.
# NOTE(review): the --seed help text says "< 0 is no seeding", but no guard
# on args.seed is visible in this excerpt -- confirm it exists in the full
# file.
125 # torch.backends.cudnn.deterministic = True
126 # torch.backends.cudnn.benchmark = False
127 # torch.use_deterministic_algorithms(True)
128 torch.manual_seed(args.seed)
129 if torch.cuda.is_available():
130 torch.cuda.manual_seed_all(args.seed)
132 ######################################################################
# Stores a clone of every parameter of `model` in self.mem, keyed by the
# parameter tensor itself, with gradient tracking disabled.
# NOTE(review): the class header and the assignments of self.model,
# self.decay and self.mem are not visible in this excerpt -- confirm.
135 def __init__(self, model, decay):
139 with torch.no_grad():
140 for p in model.parameters():
141 self.mem[p] = p.clone()
# In-place exponential moving average update of each stored copy:
# mem <- decay * mem + (1 - decay) * p.
# NOTE(review): the `def` line is not visible in this excerpt; the training
# loop calls ema.step(), so this is presumably that method -- confirm.
144 with torch.no_grad():
145 for p in self.model.parameters():
146 self.mem[p].copy_(self.decay * self.mem[p] + (1 - self.decay) * p)
# Iterates over the model's parameters with gradient tracking disabled.
# NOTE(review): the loop body is not visible in this excerpt; going by the
# method name it presumably copies the stored averages from self.mem back
# into the parameters -- confirm.
148 def copy_to_model(self):
149 with torch.no_grad():
150 for p in self.model.parameters():
153 ######################################################################
155 # Gets a pair (x, t) and appends t (scalar or 1d tensor) to x as an
156 # additional dimension / channel
# Module whose forward pass concatenates a time value to its input as one
# extra slice along dimension 1.
158 class TimeAppender(nn.Module):
162 def forward(self, u):
# NOTE(review): the unpacking of `u` into `x` and `t` is not visible in this
# excerpt -- presumably `x, t = u`; confirm.
# A non-tensor t becomes a 1d tensor with one entry per row of x, created
# through x.new_full.
164 if not torch.is_tensor(t):
165 t = x.new_full((x.size(0),), t)
# Reshape t to (-1, 1, ..., 1) with as many dimensions as x, then expand it
# to the shape of x[:, :1] so it can be concatenated along dimension 1.
166 t = t.view((-1,) + (1,) * (x.dim() - 1)).expand_as(x[:,:1])
167 return torch.cat((x, t), 1)
# Convolutional network built as a stack of six Conv2d layers in self.core.
# NOTE(review): the definitions of `nc` and `ks`, the layers between the
# convolutions, the closing of nn.Sequential and the body of forward are not
# visible in this excerpt -- confirm against the full file.
169 class ConvNet(nn.Module):
170 def __init__(self, in_channels, out_channels):
175 self.core = nn.Sequential(
# The first convolution takes in_channels + 1 input channels; the extra one
# is presumably the time channel added by TimeAppender -- confirm.
# padding = ks//2 is used on every layer.
177 nn.Conv2d(in_channels + 1, nc, ks, padding = ks//2),
179 nn.Conv2d(nc, nc, ks, padding = ks//2),
181 nn.Conv2d(nc, nc, ks, padding = ks//2),
183 nn.Conv2d(nc, nc, ks, padding = ks//2),
185 nn.Conv2d(nc, nc, ks, padding = ks//2),
187 nn.Conv2d(nc, out_channels, ks, padding = ks//2),
190 def forward(self, u):
193 ######################################################################
# Draw args.nb_samples training examples with the sampler selected by
# --data and move them to `device`.
197 train_input = samplers[args.data](args.nb_samples).to(device)
# NOTE(review): the branch that guards this message (presumably an except
# or else around the lookup above) is not visible in this excerpt.
199 print(f'unknown data {args.data}')
# Scalar mean and standard deviation over the whole training tensor; the
# training loop uses them to normalise inputs and `generate` uses them to
# map samples back to the data scale.
202 train_mean, train_std = train_input.mean(), train_input.std()
204 ######################################################################
# Build the denoising model according to the rank of the training data.
# NOTE(review): `nh`, the intermediate layers, the closing of nn.Sequential
# and any branch after the elif are not visible in this excerpt.
# 2d data (samples x features): fully connected network whose first layer
# takes the features plus one extra input and whose last layer outputs as
# many values as there are features.
207 if train_input.dim() == 2:
210 model = nn.Sequential(
212 nn.Linear(train_input.size(1) + 1, nh),
218 nn.Linear(nh, train_input.size(1)),
# 4d data: ConvNet with the same number of input and output channels.
221 elif train_input.dim() == 4:
223 model = ConvNet(train_input.size(1), train_input.size(1))
# Report the total number of scalar parameters.
227 print(f'nb_parameters {sum([ p.numel() for p in model.parameters() ])}')
229 ######################################################################
# Generates samples of shape `size`: starts from Gaussian noise on `device`
# and iterates t from T-1 down to 0, each step rescaling
# x - (1 - alpha[t]) / sqrt(1 - alpha_bar[t]) * model_output by
# 1 / sqrt(alpha[t]). The model receives the pair (x, t / (T - 1) - 0.5).
# After the loop the result is multiplied by train_std and shifted by
# train_mean. Runs with gradient tracking disabled.
# NOTE(review): the update expression ends on a line continuation whose
# next line is not visible here -- presumably the term that uses `sigma` and
# the noise `z` (zero at t == 0, Gaussian otherwise); the return statement
# is not visible either. Confirm against the full file.
232 def generate(size, T, alpha, alpha_bar, sigma, model, train_mean, train_std):
234 with torch.no_grad():
236 x = torch.randn(size, device = device)
238 for t in range(T-1, -1, -1):
239 output = model((x, t / (T - 1) - 0.5))
240 z = torch.zeros_like(x) if t == 0 else torch.randn_like(x)
241 x = 1/torch.sqrt(alpha[t]) \
242 * (x - (1-alpha[t]) / torch.sqrt(1-alpha_bar[t]) * output) \
245 x = x * train_std + train_mean
249 ######################################################################
# Noise schedule and training loop.
# NOTE(review): the definitions of `T`, `alpha` and `sigma`, the reset of
# `acc_loss`, and the backward / optimizer-step calls are not visible in
# this excerpt -- confirm against the full file.
# T values spaced linearly from 1e-4 to 0.02.
253 beta = torch.linspace(1e-4, 0.02, T, device = device)
# Cumulative product of alpha, computed as exp of the cumulative sum of logs.
255 alpha_bar = alpha.log().cumsum(0).exp()
# EMA tracking of the parameters is enabled only for a positive --ema_decay.
258 ema = EMA(model, decay = args.ema_decay) if args.ema_decay > 0 else None
260 for k in range(args.nb_epochs):
# NOTE(review): the optimizer appears to be created after the epoch loop
# header; if it is inside the loop, Adam's running statistics are reset at
# every epoch -- confirm this is intended.
263 optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
265 for x0 in train_input.split(args.batch_size):
# Normalise the batch with the global training mean and std.
266 x0 = (x0 - train_mean) / train_std
# One random step index per sample, shaped (batch, 1, ..., 1) so that
# alpha_bar[t] broadcasts against x0.
267 t = torch.randint(T, (x0.size(0),) + (1,) * (x0.dim() - 1), device = x0.device)
268 eps = torch.randn_like(x0)
# Noisy input: sqrt(alpha_bar[t]) * x0 + sqrt(1 - alpha_bar[t]) * eps.
269 xt = torch.sqrt(alpha_bar[t]) * x0 + torch.sqrt(1 - alpha_bar[t]) * eps
# The step index is rescaled by T - 1 and shifted by -0.5 before being
# passed to the model.
270 output = model((xt, t / (T - 1) - 0.5))
# Mean squared error between the injected noise and the model output.
271 loss = (eps - output).pow(2).mean()
# Accumulate the loss weighted by the batch size.
272 acc_loss += loss.item() * x0.size(0)
274 optimizer.zero_grad()
278 if ema is not None: ema.step()
# Epoch index and accumulated loss divided by the number of training samples.
280 print(f'{k} {acc_loss / train_input.size(0)}')
282 if ema is not None: ema.copy_to_model()
284 ######################################################################
289 ########################################
# Render generated samples next to training data; the output depends on the
# shape of train_input.
# NOTE(review): the creation of `fig`, some keyword arguments of the hist
# calls, any plt.show() call and the final else line are not visible in this
# excerpt -- confirm against the full file.
# One feature per sample: overlaid histograms of 10000 generated values and
# the training values, saved as a PDF.
291 if train_input.dim() == 2 and train_input.size(1) == 1:
297 ax = fig.add_subplot(1, 1, 1)
299 x = generate((10000, 1), T, alpha, alpha_bar, sigma,
300 model, train_mean, train_std)
302 ax.set_xlim(-1.25, 1.25)
303 ax.spines.right.set_visible(False)
304 ax.spines.top.set_visible(False)
# Training data: 25 bins over (-1, 1), drawn as light blue bars.
306 d = train_input.flatten().detach().to('cpu').numpy()
307 ax.hist(d, 25, (-1, 1),
309 histtype = 'bar', edgecolor = 'white', color = 'lightblue', label = 'Train')
# Generated data: same bins, drawn as a red step outline.
311 d = x.flatten().detach().to('cpu').numpy()
312 ax.hist(d, 25, (-1, 1),
314 histtype = 'step', color = 'red', label = 'Synthesis')
316 ax.legend(frameon = False, loc = 2)
318 filename = f'minidiffusion_{args.data}.pdf'
319 print(f'saving {filename}')
320 fig.savefig(filename, bbox_inches='tight')
# Set the window geometry only when windows are allowed and the figure
# manager exposes a `window` attribute.
322 if not args.no_window and hasattr(plt.get_current_fig_manager(), 'window'):
323 plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
326 ########################################
327 # Nx2 -> scatter plot
328 elif train_input.dim() == 2 and train_input.size(1) == 2:
334 ax = fig.add_subplot(1, 1, 1)
336 x = generate((1000, 2), T, alpha, alpha_bar, sigma,
337 model, train_mean, train_std)
339 ax.set_xlim(-1.5, 1.5)
340 ax.set_ylim(-1.5, 1.5)
342 ax.spines.right.set_visible(False)
343 ax.spines.top.set_visible(False)
# Show as many training points as there are generated points.
345 d = train_input[:x.size(0)].detach().to('cpu').numpy()
346 ax.scatter(d[:, 0], d[:, 1],
347 s = 2.5, color = 'gray', label = 'Train')
349 d = x.detach().to('cpu').numpy()
350 ax.scatter(d[:, 0], d[:, 1],
351 s = 2.0, color = 'red', label = 'Synthesis')
353 ax.legend(frameon = False, loc = 2)
355 filename = f'minidiffusion_{args.data}.pdf'
356 print(f'saving {filename}')
357 fig.savefig(filename, bbox_inches='tight')
359 if not args.no_window and hasattr(plt.get_current_fig_manager(), 'window'):
360 plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
363 ########################################
# 4d data: 128 generated samples and the first 128 training samples are each
# laid out as a grid with 16 per row, concatenated along dimension 2, and
# saved as a PNG.
365 elif train_input.dim() == 4:
367 x = generate((128,) + train_input.size()[1:], T, alpha, alpha_bar, sigma,
368 model, train_mean, train_std)
# Generated values are clamped to [0, 255] before building the grid.
370 x = torchvision.utils.make_grid(x.clamp(min = 0, max = 255),
371 nrow = 16, padding = 1, pad_value = 64)
372 x = F.pad(x, pad = (2, 2, 2, 2), value = 64)[None]
374 t = torchvision.utils.make_grid(train_input[:128],
375 nrow = 16, padding = 1, pad_value = 64)
376 t = F.pad(t, pad = (2, 2, 2, 2), value = 64)[None]
# Training grid first, generated grid second; values are divided by 255 and
# inverted with 1 - v.
378 result = 1 - torch.cat((t, x), 2) / 255
380 filename = f'minidiffusion_{args.data}.png'
381 print(f'saving {filename}')
382 torchvision.utils.save_image(result, filename)
# NOTE(review): the `else:` line that guards this message is not visible in
# this excerpt.
386 print(f'cannot plot result of size {train_input.size()}')
388 ######################################################################