Update.
[pytorch.git] / minidiffusion.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # Minimal implementation of Jonathan Ho, Ajay Jain, Pieter Abbeel
9 # "Denoising Diffusion Probabilistic Models" (2020)
10 #
11 # https://arxiv.org/abs/2006.11239
12
13 import math
14 import matplotlib.pyplot as plt
15 import torch
16 from torch import nn
17
18 ######################################################################
19
20 class EMA:
21     def __init__(self, model, decay = 0.9999):
22         self.model = model
23         self.decay = decay
24         self.ema = { }
25         with torch.no_grad():
26             for p in model.parameters():
27                 self.ema[p] = p.clone()
28
29     def step(self):
30         with torch.no_grad():
31             for p in self.model.parameters():
32                 self.ema[p].copy_(self.decay * self.ema[p] + (1 - self.decay) * p)
33
34     def copy(self):
35         with torch.no_grad():
36             for p in self.model.parameters():
37                 p.copy_(self.ema[p])
38
39 ######################################################################
40
41 def sample_gaussian_mixture(nb):
42     p, std = 0.3, 0.2
43     result = torch.empty(nb, 1).normal_(0, std)
44     result = result + torch.sign(torch.rand(result.size()) - p) / 2
45     return result
46
47 def sample_arc(nb):
48     theta = torch.rand(nb) * math.pi
49     rho = torch.rand(nb) * 0.1 + 0.7
50     result = torch.empty(nb, 2)
51     result[:, 0] = theta.cos() * rho
52     result[:, 1] = theta.sin() * rho
53     return result
54
55 ######################################################################
56 # Train
57
58 nb_samples = 25000
59
60 train_input = sample_gaussian_mixture(nb_samples)
61 #train_input = sample_arc(nb_samples)
62
63 ######################################################################
64
65 nh = 64
66
67 model = nn.Sequential(
68     nn.Linear(train_input.size(1) + 1, nh),
69     nn.ReLU(),
70     nn.Linear(nh, nh),
71     nn.ReLU(),
72     nn.Linear(nh, train_input.size(1)),
73 )
74
75 nb_epochs = 50
76 batch_size = 25
77
78 T = 1000
79 beta = torch.linspace(1e-4, 0.02, T)
80 alpha = 1 - beta
81 alpha_bar = alpha.log().cumsum(0).exp()
82 sigma = beta.sqrt()
83
84 ema = EMA(model)
85
86 for k in range(nb_epochs):
87
88     acc_loss = 0
89     optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
90
91     for x0 in train_input.split(batch_size):
92         t = torch.randint(T, (x0.size(0), 1))
93         eps = torch.randn(x0.size())
94         input = alpha_bar[t].sqrt() * x0 + (1 - alpha_bar[t]).sqrt() * eps
95         input = torch.cat((input, 2 * t / T - 1), 1)
96         output = model(input)
97         loss = (eps - output).pow(2).mean()
98         optimizer.zero_grad()
99         loss.backward()
100         optimizer.step()
101
102         acc_loss += loss.item()
103
104         ema.step()
105
106     if k%10 == 0: print(k, loss.item())
107
108 ema.copy()
109
110 ######################################################################
111 # Generate
112
113 x = torch.randn(10000, train_input.size(1))
114
115 for t in range(T-1, -1, -1):
116     z = torch.zeros(x.size()) if t == 0 else torch.randn(x.size())
117     input = torch.cat((x, torch.ones(x.size(0), 1) * 2 * t / T - 1), 1)
118     x = 1 / alpha[t].sqrt() * (x - (1 - alpha[t])/(1 - alpha_bar[t]).sqrt() * model(input)) \
119         + sigma[t] * z
120
121 ######################################################################
122 # Plot
123
124 fig = plt.figure()
125 ax = fig.add_subplot(1, 1, 1)
126
127 if train_input.size(1) == 1:
128
129     ax.set_xlim(-1.25, 1.25)
130
131     d = train_input.flatten().detach().numpy()
132     ax.hist(d, 25, (-1, 1),
133             density = True,
134             histtype = 'stepfilled', color = 'lightblue', label = 'Train')
135
136     d = x.flatten().detach().numpy()
137     ax.hist(d, 25, (-1, 1),
138             density = True,
139             histtype = 'step', color = 'red', label = 'Synthesis')
140
141     ax.legend(frameon = False, loc = 2)
142
143 elif train_input.size(1) == 2:
144
145     ax.set_xlim(-1.25, 1.25)
146     ax.set_ylim(-1.25, 1.25)
147     ax.set(aspect = 1)
148
149     d = train_input[:200].detach().numpy()
150     ax.scatter(d[:, 0], d[:, 1],
151                color = 'lightblue', label = 'Train')
152
153     d = x[:200].detach().numpy()
154     ax.scatter(d[:, 0], d[:, 1],
155                color = 'red', label = 'Synthesis')
156
157     ax.legend(frameon = False, loc = 2)
158
159 filename = 'diffusion.pdf'
160 print(f'saving {filename}')
161 fig.savefig(filename, bbox_inches='tight')
162
163 plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
164 plt.show()
165
166 ######################################################################