# Written by Francois Fleuret <francois@fleuret.org>
-# Minimal implementation of Jonathan Ho, Ajay Jain, Pieter Abbeel
-# "Denoising Diffusion Probabilistic Models" (2020)
-#
-# https://arxiv.org/abs/2006.11239
+import math, argparse
-import math
import matplotlib.pyplot as plt
+
import torch
from torch import nn
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+######################################################################
+
+def sample_gaussian_mixture(nb):
+ p, std = 0.3, 0.2
+ result = torch.empty(nb, 1, device = device).normal_(0, std)
+ result = result + torch.sign(torch.rand(result.size(), device = device) - p) / 2
+ return result
+
+def sample_arc(nb):
+ theta = torch.rand(nb, device = device) * math.pi
+ rho = torch.rand(nb, device = device) * 0.1 + 0.7
+ result = torch.empty(nb, 2, device = device)
+ result[:, 0] = theta.cos() * rho
+ result[:, 1] = theta.sin() * rho
+ return result
+
+def sample_spiral(nb):
+ u = torch.rand(nb, device = device)
+ rho = u * 0.65 + 0.25 + torch.rand(nb, device = device) * 0.15
+ theta = u * math.pi * 3
+ result = torch.empty(nb, 2, device = device)
+ result[:, 0] = theta.cos() * rho
+ result[:, 1] = theta.sin() * rho
+ return result
+
+samplers = {
+ 'gaussian_mixture': sample_gaussian_mixture,
+ 'arc': sample_arc,
+ 'spiral': sample_spiral,
+}
+
+######################################################################
+
+parser = argparse.ArgumentParser(
+ description = '''A minimal implementation of Jonathan Ho, Ajay Jain, Pieter Abbeel
+"Denoising Diffusion Probabilistic Models" (2020)
+https://arxiv.org/abs/2006.11239''',
+
+ formatter_class = argparse.ArgumentDefaultsHelpFormatter
+)
+
+parser.add_argument('--seed',
+ type = int, default = 0,
+ help = 'Random seed, < 0 is no seeding')
+
+parser.add_argument('--nb_epochs',
+ type = int, default = 100,
+ help = 'How many epochs')
+
+parser.add_argument('--batch_size',
+ type = int, default = 25,
+ help = 'Batch size')
+
+parser.add_argument('--nb_samples',
+ type = int, default = 25000,
+ help = 'Number of training examples')
+
+parser.add_argument('--learning_rate',
+ type = float, default = 1e-3,
+ help = 'Learning rate')
+
+parser.add_argument('--ema_decay',
+ type = float, default = 0.9999,
+ help = 'EMA decay, < 0 means no EMA')
+
+data_list = ', '.join( [ str(k) for k in samplers ])
+
+parser.add_argument('--data',
+ type = str, default = 'gaussian_mixture',
+ help = f'Toy data-set to use: {data_list}')
+
+args = parser.parse_args()
+
+if args.seed >= 0:
+ # torch.backends.cudnn.deterministic = True
+ # torch.backends.cudnn.benchmark = False
+ # torch.use_deterministic_algorithms(True)
+ torch.manual_seed(args.seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(args.seed)
+
######################################################################
class EMA:
- def __init__(self, model, decay = 0.9999):
+ def __init__(self, model, decay):
self.model = model
self.decay = decay
+ if self.decay < 0: return
self.ema = { }
with torch.no_grad():
for p in model.parameters():
self.ema[p] = p.clone()
def step(self):
+ if self.decay < 0: return
with torch.no_grad():
for p in self.model.parameters():
self.ema[p].copy_(self.decay * self.ema[p] + (1 - self.decay) * p)
def copy(self):
+ if self.decay < 0: return
with torch.no_grad():
for p in self.model.parameters():
p.copy_(self.ema[p])
-######################################################################
-
-def sample_gaussian_mixture(nb):
- p, std = 0.3, 0.2
- result = torch.empty(nb, 1).normal_(0, std)
- result = result + torch.sign(torch.rand(result.size()) - p) / 2
- return result
-
-def sample_arc(nb):
- theta = torch.rand(nb) * math.pi
- rho = torch.rand(nb) * 0.1 + 0.7
- result = torch.empty(nb, 2)
- result[:, 0] = theta.cos() * rho
- result[:, 1] = theta.sin() * rho
- return result
-
######################################################################
# Train
-nb_samples = 25000
-
-train_input = sample_gaussian_mixture(nb_samples)
-#train_input = sample_arc(nb_samples)
+try:
+ train_input = samplers[args.data](args.nb_samples)
+except KeyError:
+ print(f'unknown data {args.data}')
+ exit(1)
######################################################################
nn.ReLU(),
nn.Linear(nh, nh),
nn.ReLU(),
+ nn.Linear(nh, nh),
+ nn.ReLU(),
nn.Linear(nh, train_input.size(1)),
-)
-
-nb_epochs = 50
-batch_size = 25
+).to(device)
T = 1000
-beta = torch.linspace(1e-4, 0.02, T)
+beta = torch.linspace(1e-4, 0.02, T, device = device)
alpha = 1 - beta
alpha_bar = alpha.log().cumsum(0).exp()
sigma = beta.sqrt()
-ema = EMA(model)
+ema = EMA(model, decay = args.ema_decay)
-for k in range(nb_epochs):
+for k in range(args.nb_epochs):
acc_loss = 0
- optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
+ optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
- for x0 in train_input.split(batch_size):
- t = torch.randint(T, (x0.size(0), 1))
- eps = torch.randn(x0.size())
+ for x0 in train_input.split(args.batch_size):
+ t = torch.randint(T, (x0.size(0), 1), device = device)
+ eps = torch.randn(x0.size(), device = device)
input = alpha_bar[t].sqrt() * x0 + (1 - alpha_bar[t]).sqrt() * eps
input = torch.cat((input, 2 * t / T - 1), 1)
output = model(input)
loss.backward()
optimizer.step()
- acc_loss += loss.item()
+ acc_loss += loss.item() * x0.size(0)
ema.step()
- if k%10 == 0: print(k, loss.item())
+ if k%10 == 0: print(f'{k} {acc_loss / train_input.size(0)}')
ema.copy()
######################################################################
# Generate
-x = torch.randn(10000, train_input.size(1))
+x = torch.randn(10000, train_input.size(1), device = device)
for t in range(T-1, -1, -1):
- z = torch.zeros(x.size()) if t == 0 else torch.randn(x.size())
- input = torch.cat((x, torch.ones(x.size(0), 1) * 2 * t / T - 1), 1)
+ z = torch.zeros(x.size(), device = device) if t == 0 else torch.randn(x.size(), device = device)
+ input = torch.cat((x, torch.ones(x.size(0), 1, device = device) * 2 * t / T - 1), 1)
x = 1 / alpha[t].sqrt() * (x - (1 - alpha[t])/(1 - alpha_bar[t]).sqrt() * model(input)) \
+ sigma[t] * z
ax.set_xlim(-1.25, 1.25)
- d = train_input.flatten().detach().numpy()
+ d = train_input.flatten().detach().to('cpu').numpy()
ax.hist(d, 25, (-1, 1),
density = True,
histtype = 'stepfilled', color = 'lightblue', label = 'Train')
- d = x.flatten().detach().numpy()
+ d = x.flatten().detach().to('cpu').numpy()
ax.hist(d, 25, (-1, 1),
density = True,
histtype = 'step', color = 'red', label = 'Synthesis')
ax.set_ylim(-1.25, 1.25)
ax.set(aspect = 1)
- d = train_input[:200].detach().numpy()
+ d = train_input[:200].detach().to('cpu').numpy()
ax.scatter(d[:, 0], d[:, 1],
color = 'lightblue', label = 'Train')
- d = x[:200].detach().numpy()
+ d = x[:200].detach().to('cpu').numpy()
ax.scatter(d[:, 0], d[:, 1],
color = 'red', label = 'Synthesis')
ax.legend(frameon = False, loc = 2)
-filename = 'diffusion.pdf'
+filename = f'diffusion_{args.data}.pdf'
print(f'saving {filename}')
fig.savefig(filename, bbox_inches='tight')
-plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
-plt.show()
+if hasattr(plt.get_current_fig_manager(), 'window'):
+ plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
+ plt.show()
######################################################################