Update.
[pytorch.git] / minidiffusion.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, argparse
9
10 import matplotlib.pyplot as plt
11
12 import torch
13 from torch import nn
14
15 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
17 ######################################################################
18
19 def sample_gaussian_mixture(nb):
20     p, std = 0.3, 0.2
21     result = torch.empty(nb, 1, device = device).normal_(0, std)
22     result = result + torch.sign(torch.rand(result.size(), device = device) - p) / 2
23     return result
24
25 def sample_arc(nb):
26     theta = torch.rand(nb, device = device) * math.pi
27     rho = torch.rand(nb, device = device) * 0.1 + 0.7
28     result = torch.empty(nb, 2, device = device)
29     result[:, 0] = theta.cos() * rho
30     result[:, 1] = theta.sin() * rho
31     return result
32
33 def sample_spiral(nb):
34     u = torch.rand(nb, device = device)
35     rho = u * 0.65 + 0.25 + torch.rand(nb, device = device) * 0.15
36     theta = u * math.pi * 3
37     result = torch.empty(nb, 2, device = device)
38     result[:, 0] = theta.cos() * rho
39     result[:, 1] = theta.sin() * rho
40     return result
41
42 samplers = {
43     'gaussian_mixture': sample_gaussian_mixture,
44     'arc': sample_arc,
45     'spiral': sample_spiral,
46 }
47
48 ######################################################################
49
50 parser = argparse.ArgumentParser(
51     description = '''A minimal implementation of Jonathan Ho, Ajay Jain, Pieter Abbeel
52 "Denoising Diffusion Probabilistic Models" (2020)
53 https://arxiv.org/abs/2006.11239''',
54
55     formatter_class = argparse.ArgumentDefaultsHelpFormatter
56 )
57
58 parser.add_argument('--seed',
59                     type = int, default = 0,
60                     help = 'Random seed, < 0 is no seeding')
61
62 parser.add_argument('--nb_epochs',
63                     type = int, default = 100,
64                     help = 'How many epochs')
65
66 parser.add_argument('--batch_size',
67                     type = int, default = 25,
68                     help = 'Batch size')
69
70 parser.add_argument('--nb_samples',
71                     type = int, default = 25000,
72                     help = 'Number of training examples')
73
74 parser.add_argument('--learning_rate',
75                     type = float, default = 1e-3,
76                     help = 'Learning rate')
77
78 parser.add_argument('--ema_decay',
79                     type = float, default = 0.9999,
80                     help = 'EMA decay, < 0 is no EMA')
81
82 data_list = ', '.join( [ str(k) for k in samplers ])
83
84 parser.add_argument('--data',
85                     type = str, default = 'gaussian_mixture',
86                     help = f'Toy data-set to use: {data_list}')
87
88 args = parser.parse_args()
89
90 if args.seed >= 0:
91     # torch.backends.cudnn.deterministic = True
92     # torch.backends.cudnn.benchmark = False
93     # torch.use_deterministic_algorithms(True)
94     torch.manual_seed(args.seed)
95     if torch.cuda.is_available():
96         torch.cuda.manual_seed_all(args.seed)
97
98 ######################################################################
99
100 class EMA:
101     def __init__(self, model, decay):
102         self.model = model
103         self.decay = decay
104         if self.decay < 0: return
105         self.ema = { }
106         with torch.no_grad():
107             for p in model.parameters():
108                 self.ema[p] = p.clone()
109
110     def step(self):
111         if self.decay < 0: return
112         with torch.no_grad():
113             for p in self.model.parameters():
114                 self.ema[p].copy_(self.decay * self.ema[p] + (1 - self.decay) * p)
115
116     def copy(self):
117         if self.decay < 0: return
118         with torch.no_grad():
119             for p in self.model.parameters():
120                 p.copy_(self.ema[p])
121
122 ######################################################################
123 # Train
124
125 try:
126     train_input = samplers[args.data](args.nb_samples)
127 except KeyError:
128     print(f'unknown data {args.data}')
129     exit(1)
130
131 ######################################################################
132
133 nh = 64
134
135 model = nn.Sequential(
136     nn.Linear(train_input.size(1) + 1, nh),
137     nn.ReLU(),
138     nn.Linear(nh, nh),
139     nn.ReLU(),
140     nn.Linear(nh, nh),
141     nn.ReLU(),
142     nn.Linear(nh, train_input.size(1)),
143 ).to(device)
144
145 T = 1000
146 beta = torch.linspace(1e-4, 0.02, T, device = device)
147 alpha = 1 - beta
148 alpha_bar = alpha.log().cumsum(0).exp()
149 sigma = beta.sqrt()
150
151 ema = EMA(model, decay = args.ema_decay)
152
153 for k in range(args.nb_epochs):
154
155     acc_loss = 0
156     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
157
158     for x0 in train_input.split(args.batch_size):
159         t = torch.randint(T, (x0.size(0), 1), device = device)
160         eps = torch.randn(x0.size(), device = device)
161         input = alpha_bar[t].sqrt() * x0 + (1 - alpha_bar[t]).sqrt() * eps
162         input = torch.cat((input, 2 * t / T - 1), 1)
163         output = model(input)
164         loss = (eps - output).pow(2).mean()
165         optimizer.zero_grad()
166         loss.backward()
167         optimizer.step()
168
169         acc_loss += loss.item() * x0.size(0)
170
171         ema.step()
172
173     if k%10 == 0: print(f'{k} {acc_loss / train_input.size(0)}')
174
175 ema.copy()
176
177 ######################################################################
178 # Generate
179
180 x = torch.randn(10000, train_input.size(1), device = device)
181
182 for t in range(T-1, -1, -1):
183     z = torch.zeros(x.size(), device = device) if t == 0 else torch.randn(x.size(), device = device)
184     input = torch.cat((x, torch.ones(x.size(0), 1, device = device) * 2 * t / T - 1), 1)
185     x = 1 / alpha[t].sqrt() * (x - (1 - alpha[t])/(1 - alpha_bar[t]).sqrt() * model(input)) \
186         + sigma[t] * z
187
188 ######################################################################
189 # Plot
190
191 fig = plt.figure()
192 ax = fig.add_subplot(1, 1, 1)
193
194 if train_input.size(1) == 1:
195
196     ax.set_xlim(-1.25, 1.25)
197
198     d = train_input.flatten().detach().to('cpu').numpy()
199     ax.hist(d, 25, (-1, 1),
200             density = True,
201             histtype = 'stepfilled', color = 'lightblue', label = 'Train')
202
203     d = x.flatten().detach().to('cpu').numpy()
204     ax.hist(d, 25, (-1, 1),
205             density = True,
206             histtype = 'step', color = 'red', label = 'Synthesis')
207
208     ax.legend(frameon = False, loc = 2)
209
210 elif train_input.size(1) == 2:
211
212     ax.set_xlim(-1.25, 1.25)
213     ax.set_ylim(-1.25, 1.25)
214     ax.set(aspect = 1)
215
216     d = train_input[:200].detach().to('cpu').numpy()
217     ax.scatter(d[:, 0], d[:, 1],
218                color = 'lightblue', label = 'Train')
219
220     d = x[:200].detach().to('cpu').numpy()
221     ax.scatter(d[:, 0], d[:, 1],
222                color = 'red', label = 'Synthesis')
223
224     ax.legend(frameon = False, loc = 2)
225
226 filename = f'diffusion_{args.data}.pdf'
227 print(f'saving {filename}')
228 fig.savefig(filename, bbox_inches='tight')
229
230 if hasattr(plt.get_current_fig_manager(), 'window'):
231     plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
232     plt.show()
233
234 ######################################################################