# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>
import argparse
import math

import matplotlib.pyplot as plt

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"device {device}")
20 ######################################################################
def sample_gaussian_mixture(nb):
    """Sample nb points from a 1d two-component Gaussian mixture.

    Each sample is Gaussian noise of std `std`, shifted by -1/2 or +1/2
    (the component is picked with probability p / 1-p). Returns a
    (nb, 1) tensor.
    """
    # NOTE(review): the mixture constants were lost in the mangled
    # listing; p=0.3, std=0.2 restored from the upstream script -- confirm.
    p, std = 0.3, 0.2
    result = torch.randn(nb, 1) * std
    # sign(U(0,1) - p) is +1 with probability 1-p, -1 with probability p.
    result = result + torch.sign(torch.rand(result.size()) - p) / 2
    return result
31 result = torch.min(torch.rand(nb, 1), torch.rand(nb, 1))
def sample_two_discs(nb):
    """Sample nb 2d points uniformly over two discs.

    Disc q=0 has radius 0.3 and center (-0.5, -0.5); disc q=1 has
    radius 0.5 and center (0.5, 0.5). The sqrt on the radial draw makes
    the density uniform over each disc's area. Returns a (nb, 2) tensor.
    """
    a = torch.rand(nb) * math.pi * 2
    b = torch.rand(nb).sqrt()
    # Pick one of the two discs per sample with equal probability.
    q = (torch.rand(nb) <= 0.5).long()
    b = b * (0.3 + 0.2 * q)
    result = torch.empty(nb, 2)
    result[:, 0] = a.cos() * b - 0.5 + q
    result[:, 1] = a.sin() * b - 0.5 + q
    return result
def sample_disc_grid(nb):
    """Sample nb 2d points from small discs centered on an N x N grid.

    Grid centers span [-1, 1] in both coordinates. Returns a (nb, 2)
    tensor.
    """
    a = torch.rand(nb) * math.pi * 2
    b = torch.rand(nb).sqrt()
    # NOTE(review): the grid size and disc radius were lost in the
    # mangled listing; N=4 and radius 0.1 restored from the upstream
    # script -- confirm.
    N = 4
    # Pick a grid center coordinate in {-1, ..., +1} per axis.
    q = (torch.randint(N, (nb,)) - (N - 1) / 2) / ((N - 1) / 2)
    r = (torch.randint(N, (nb,)) - (N - 1) / 2) / ((N - 1) / 2)
    b = b * 0.1
    result = torch.empty(nb, 2)
    result[:, 0] = a.cos() * b + q
    result[:, 1] = a.sin() * b + r
    return result
def sample_spiral(nb):
    """Sample nb 2d points along a noisy spiral (1.5 turns).

    The radius grows with the angular parameter u, plus a small radial
    jitter. Returns a (nb, 2) tensor.
    """
    # NOTE(review): the draw of u was lost in the mangled listing;
    # u ~ U(0,1) restored from its use in rho and theta -- confirm.
    u = torch.rand(nb)
    rho = u * 0.65 + 0.25 + torch.rand(nb) * 0.15
    theta = u * math.pi * 3
    result = torch.empty(nb, 2)
    result[:, 0] = theta.cos() * rho
    result[:, 1] = theta.sin() * rho
    return result
def sample_mnist(nb):
    """Return the first nb MNIST train images as a (nb, 1, 28, 28) float
    tensor on `device`, downloading the data set under ./data/ if needed."""
    train_set = torchvision.datasets.MNIST(root="./data/", train=True, download=True)
    result = train_set.data[:nb].to(device).view(-1, 1, 28, 28).float()
    return result
# Map each data-set name (the function name with its "sample_" prefix
# stripped) to its sampler, e.g. "gaussian_mixture" -> sample_gaussian_mixture.
# NOTE(review): the list of samplers was partially lost in the mangled
# listing; restored with all sample_* functions defined above -- confirm.
samplers = {
    f.__name__.removeprefix("sample_"): f
    for f in [
        sample_gaussian_mixture,
        sample_ramp,
        sample_two_discs,
        sample_disc_grid,
        sample_spiral,
        sample_mnist,
    ]
}
87 ######################################################################
parser = argparse.ArgumentParser(
    description="""A minimal implementation of Jonathan Ho, Ajay Jain, Pieter Abbeel
"Denoising Diffusion Probabilistic Models" (2020)
https://arxiv.org/abs/2006.11239""",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument(
    "--seed", type=int, default=0, help="Random seed, < 0 is no seeding"
)

parser.add_argument("--nb_epochs", type=int, default=100, help="How many epochs")

parser.add_argument("--batch_size", type=int, default=25, help="Batch size")

parser.add_argument(
    "--nb_samples", type=int, default=25000, help="Number of training examples"
)

parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate")

parser.add_argument(
    "--ema_decay", type=float, default=0.9999, help="EMA decay, <= 0 is no EMA"
)

# List the available toy data sets in the --data help string.
data_list = ", ".join([str(k) for k in samplers])

parser.add_argument(
    "--data",
    type=str,
    default="gaussian_mixture",
    help=f"Toy data-set to use: {data_list}",
)

# When set, do not try to move / resize the matplotlib window.
parser.add_argument("--no_window", action="store_true", default=False)

args = parser.parse_args()
# NOTE(review): the guard line was lost in the mangled listing; restored
# to match the --seed help text "< 0 is no seeding" -- confirm.
if args.seed >= 0:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    # torch.use_deterministic_algorithms(True)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
135 ######################################################################
class EMA:
    """Exponential moving average of a model's parameters.

    Keeps a shadow copy of every parameter; step() blends the current
    parameter values into the shadow with weight (1 - decay), and
    copy_to_model() writes the averaged values back into the model.
    """

    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        # Shadow storage, keyed by the parameter tensors themselves.
        self.mem = {}
        with torch.no_grad():
            for p in model.parameters():
                self.mem[p] = p.clone()

    def step(self):
        """Update the shadow copies from the current parameters."""
        with torch.no_grad():
            for p in self.model.parameters():
                self.mem[p].copy_(self.decay * self.mem[p] + (1 - self.decay) * p)

    def copy_to_model(self):
        """Overwrite the model's parameters with the averaged values."""
        with torch.no_grad():
            for p in self.model.parameters():
                p.copy_(self.mem[p])
158 ######################################################################
# Gets a pair (x, t) and appends t (scalar or 1d tensor) to x as an
# additional dimension / channel
class TimeAppender(nn.Module):
    """Append the time value t to x as an extra dimension / channel."""

    def __init__(self):
        super().__init__()

    def forward(self, u):
        x, t = u
        # A scalar t is broadcast to one value per sample.
        if not torch.is_tensor(t):
            t = x.new_full((x.size(0),), t)
        # Reshape t to (N, 1, 1, ...) and expand it to one channel of x.
        t = t.view((-1,) + (1,) * (x.dim() - 1)).expand_as(x[:, :1])
        return torch.cat((x, t), 1)
class ConvNet(nn.Module):
    """Plain convolutional denoiser: (x, t) -> noise estimate of the same
    spatial size as x, with the time appended as an extra input channel."""

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # NOTE(review): the kernel size / width constants and the
        # activations between the convolutions were lost in the mangled
        # listing; ks=5, nc=64 and ReLU restored from the upstream
        # script -- confirm.
        ks, nc = 5, 64

        self.core = nn.Sequential(
            TimeAppender(),
            nn.Conv2d(in_channels + 1, nc, ks, padding=ks // 2),
            nn.ReLU(),
            nn.Conv2d(nc, nc, ks, padding=ks // 2),
            nn.ReLU(),
            nn.Conv2d(nc, nc, ks, padding=ks // 2),
            nn.ReLU(),
            nn.Conv2d(nc, nc, ks, padding=ks // 2),
            nn.ReLU(),
            nn.Conv2d(nc, nc, ks, padding=ks // 2),
            nn.ReLU(),
            nn.Conv2d(nc, out_channels, ks, padding=ks // 2),
        )

    def forward(self, u):
        return self.core(u)
201 ######################################################################
# NOTE(review): the error-handling scaffolding was lost in the mangled
# listing; the "unknown data" print implies a lookup-failure branch,
# restored as try/except -- confirm.
try:
    train_input = samplers[args.data](args.nb_samples).to(device)
except KeyError:
    print(f"unknown data {args.data}")
    exit(1)

# Normalization statistics used for training and undone at generation.
train_mean, train_std = train_input.mean(), train_input.std()
212 ######################################################################
# Build a denoiser matched to the sample shape: an MLP for flat (N, D)
# data, a ConvNet for (N, C, H, W) images.
if train_input.dim() == 2:
    # NOTE(review): the hidden width was lost in the mangled listing;
    # nh=256 and the ReLUs restored from the upstream script -- confirm.
    nh = 256

    model = nn.Sequential(
        TimeAppender(),
        nn.Linear(train_input.size(1) + 1, nh),
        nn.ReLU(),
        nn.Linear(nh, nh),
        nn.ReLU(),
        nn.Linear(nh, nh),
        nn.ReLU(),
        nn.Linear(nh, train_input.size(1)),
    )

elif train_input.dim() == 4:
    model = ConvNet(train_input.size(1), train_input.size(1))

model.to(device)

print(f"nb_parameters {sum([ p.numel() for p in model.parameters() ])}")
236 ######################################################################
def generate(size, T, alpha, alpha_bar, sigma, model, train_mean, train_std):
    """Run the reverse diffusion (DDPM sampling, algorithm 2 of the paper).

    Starts from x_T ~ N(0, I) of the given size and applies the learned
    denoising steps from t = T-1 down to t = 0, then maps the result back
    to the data scale with train_mean / train_std.
    """
    with torch.no_grad():
        x = torch.randn(size, device=device)

        for t in range(T - 1, -1, -1):
            # The model receives the time as a scalar in [-0.5, 0.5].
            output = model((x, t / (T - 1) - 0.5))
            # No noise is added at the final step.
            z = torch.zeros_like(x) if t == 0 else torch.randn_like(x)
            # NOTE(review): the closing "+ sigma[t] * z" term of this
            # update was lost in the mangled listing; restored from the
            # DDPM sampling rule -- confirm.
            x = (
                1
                / torch.sqrt(alpha[t])
                * (x - (1 - alpha[t]) / torch.sqrt(1 - alpha_bar[t]) * output)
                + sigma[t] * z
            )

        # Undo the training-time normalization.
        x = x * train_std + train_mean

        return x
259 ######################################################################
# Linear noise schedule from the DDPM paper.
# NOTE(review): T and the derived alpha / sigma definitions were lost in
# the mangled listing; restored as T=1000, alpha = 1 - beta,
# sigma = beta.sqrt() from the upstream script -- confirm.
T = 1000
beta = torch.linspace(1e-4, 0.02, T, device=device)
alpha = 1 - beta
alpha_bar = alpha.log().cumsum(0).exp()
sigma = beta.sqrt()

ema = EMA(model, decay=args.ema_decay) if args.ema_decay > 0 else None

for k in range(args.nb_epochs):
    acc_loss = 0

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    for x0 in train_input.split(args.batch_size):
        x0 = (x0 - train_mean) / train_std
        # One random time step per sample, shaped to broadcast against x0.
        t = torch.randint(T, (x0.size(0),) + (1,) * (x0.dim() - 1), device=x0.device)
        eps = torch.randn_like(x0)
        # Closed-form forward diffusion from x0 to x_t.
        xt = torch.sqrt(alpha_bar[t]) * x0 + torch.sqrt(1 - alpha_bar[t]) * eps
        output = model((xt, t / (T - 1) - 0.5))
        loss = (eps - output).pow(2).mean()
        acc_loss += loss.item() * x0.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if ema is not None:
            ema.step()

    print(f"{k} {acc_loss / train_input.size(0)}")

# Sample with the averaged weights when EMA is enabled.
if ema is not None:
    ema.copy_to_model()
295 ######################################################################
300 ########################################
# Visualize training data against generated samples. The original
# print statements had their {filename} placeholders destroyed in the
# mangled listing ("saving (unknown)"); restored below.
# Nx1 -> histogram
if train_input.dim() == 2 and train_input.size(1) == 1:
    fig = plt.figure()

    ax = fig.add_subplot(1, 1, 1)

    x = generate((10000, 1), T, alpha, alpha_bar, sigma, model, train_mean, train_std)

    ax.set_xlim(-1.25, 1.25)
    ax.spines.right.set_visible(False)
    ax.spines.top.set_visible(False)

    d = train_input.flatten().detach().to("cpu").numpy()
    ax.hist(
        d,
        25,
        (-1, 1),
        density=True,
        histtype="stepfilled",
        color="lightblue",
        label="Train",
    )

    d = x.flatten().detach().to("cpu").numpy()
    ax.hist(
        d, 25, (-1, 1), density=True, histtype="step", color="red", label="Synthesis"
    )

    ax.legend(frameon=False, loc=2)

    filename = f"minidiffusion_{args.data}.pdf"
    print(f"saving {filename}")
    fig.savefig(filename, bbox_inches="tight")

    if not args.no_window and hasattr(plt.get_current_fig_manager(), "window"):
        plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
        plt.show()

########################################
# Nx2 -> scatter plot

elif train_input.dim() == 2 and train_input.size(1) == 2:
    fig = plt.figure()

    ax = fig.add_subplot(1, 1, 1)

    x = generate((1000, 2), T, alpha, alpha_bar, sigma, model, train_mean, train_std)

    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-1.5, 1.5)
    ax.set_aspect(1)
    ax.spines.right.set_visible(False)
    ax.spines.top.set_visible(False)

    d = train_input[: x.size(0)].detach().to("cpu").numpy()
    ax.scatter(d[:, 0], d[:, 1], s=2.5, color="gray", label="Train")

    d = x.detach().to("cpu").numpy()
    ax.scatter(d[:, 0], d[:, 1], s=2.0, color="red", label="Synthesis")

    ax.legend(frameon=False, loc=2)

    filename = f"minidiffusion_{args.data}.pdf"
    print(f"saving {filename}")
    fig.savefig(filename, bbox_inches="tight")

    if not args.no_window and hasattr(plt.get_current_fig_manager(), "window"):
        plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
        plt.show()

########################################
# NxCxHxW -> image grid

elif train_input.dim() == 4:
    x = generate(
        (128,) + train_input.size()[1:],
        T,
        alpha,
        alpha_bar,
        sigma,
        model,
        train_mean,
        train_std,
    )

    # Lay out generated and training samples as two padded image grids,
    # then map to white-background grayscale in [0, 1].
    x = torchvision.utils.make_grid(
        x.clamp(min=0, max=255), nrow=16, padding=1, pad_value=64
    )
    x = F.pad(x, pad=(2, 2, 2, 2), value=64)[None]

    t = torchvision.utils.make_grid(train_input[:128], nrow=16, padding=1, pad_value=64)
    t = F.pad(t, pad=(2, 2, 2, 2), value=64)[None]

    result = 1 - torch.cat((t, x), 2) / 255

    filename = f"minidiffusion_{args.data}.png"
    print(f"saving {filename}")
    torchvision.utils.save_image(result, filename)

else:
    print(f"cannot plot result of size {train_input.size()}")
406 ######################################################################