Update.
[pytorch.git] / minidiffusion.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # Minimal implementation of Jonathan Ho, Ajay Jain, Pieter Abbeel
9 # "Denoising Diffusion Probabilistic Models" (2020)
10 #
11 # https://arxiv.org/abs/2006.11239
12
13 import matplotlib.pyplot as plt
14 import torch
15 from torch import nn
16
17 ######################################################################
18
19 def sample_phi(nb):
20     p, std = 0.3, 0.2
21     result = torch.empty(nb).normal_(0, std)
22     result = result + torch.sign(torch.rand(result.size()) - p) / 2
23     return result
24
25 ######################################################################
26
27 model = nn.Sequential(
28     nn.Linear(2, 32),
29     nn.ReLU(),
30     nn.Linear(32, 32),
31     nn.ReLU(),
32     nn.Linear(32, 1),
33 )
34
35 ######################################################################
36 # Train
37
38 nb_samples = 25000
39 nb_epochs = 250
40 batch_size = 100
41
42 train_input = sample_phi(nb_samples)[:, None]
43
44 T = 1000
45 beta = torch.linspace(1e-4, 0.02, T)
46 alpha = 1 - beta
47 alpha_bar = alpha.log().cumsum(0).exp()
48 sigma = beta.sqrt()
49
50 for k in range(nb_epochs):
51
52     acc_loss = 0
53     optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4 * (1 - k / nb_epochs) )
54
55     for x0 in train_input.split(batch_size):
56         t = torch.randint(T, (x0.size(0), 1))
57         eps = torch.randn(x0.size())
58         input = alpha_bar[t].sqrt() * x0 + (1 - alpha_bar[t]).sqrt() * eps
59         input = torch.cat((input, 2 * t / T - 1), 1)
60         output = model(input)
61         loss = (eps - output).pow(2).mean()
62         optimizer.zero_grad()
63         loss.backward()
64         optimizer.step()
65
66         acc_loss += loss.item()
67
68     if k%10 == 0: print(k, loss.item())
69
70 ######################################################################
71 # Generate
72
73 x = torch.randn(10000, 1)
74
75 for t in range(T-1, -1, -1):
76     z = torch.zeros(x.size()) if t == 0 else torch.randn(x.size())
77     input = torch.cat((x, torch.ones(x.size(0), 1) * 2 * t / T - 1), 1)
78     x = 1 / alpha[t].sqrt() * (x - (1 - alpha[t])/(1 - alpha_bar[t]).sqrt() * model(input)) \
79         + sigma[t] * z
80
81 ######################################################################
82 # Plot
83
84 fig = plt.figure()
85 ax = fig.add_subplot(1, 1, 1)
86 ax.set_xlim(-1.25, 1.25)
87
88 d = train_input.flatten().detach().numpy()
89 ax.hist(d, 25, (-1, 1),
90         density = True,
91         histtype = 'stepfilled', color = 'lightblue', label = 'Train')
92
93 d = x.flatten().detach().numpy()
94 ax.hist(d, 25, (-1, 1),
95         density = True,
96         histtype = 'step', color = 'red', label = 'Synthesis')
97
98 ax.legend(frameon = False, loc = 2)
99
100 filename = 'diffusion.pdf'
101 print(f'saving {filename}')
102 fig.savefig(filename, bbox_inches='tight')
103
104 plt.show()
105
106 ######################################################################