# Written by Francois Fleuret <francois@fleuret.org>
# Minimal implementation of Jonathan Ho, Ajay Jain, Pieter Abbeel
# "Denoising Diffusion Probabilistic Models" (2020)
#
# https://arxiv.org/abs/2006.11239

import matplotlib.pyplot as plt
import torch
from torch import nn
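
######################################################################
# Data, noise schedule, model, training
#
# The code down to the "Generate" section is a hedged sketch of what
# this excerpt elides: a toy 1d training set, the DDPM noise schedule,
# a small noise-prediction network, and the training loop that the
# logging line below belongs to. The schedule follows the cited
# paper; the dataset, the network size, and the hyper-parameter
# values (T, nb_epochs, batch_size, learning rate) are assumptions,
# not necessarily the original choices.

# Toy training data: a two-mode Gaussian mixture on the real line
train_input = torch.cat((
    torch.randn(12500, 1) * 0.1 - 0.5,
    torch.randn(12500, 1) * 0.1 + 0.5,
))

# Noise schedule: beta_t linear in t, alpha_t = 1 - beta_t,
# alpha_bar_t = prod_{s <= t} alpha_s, sigma_t = sqrt(beta_t)
T = 1000
beta = torch.linspace(1e-4, 0.02, T)
alpha = 1 - beta
alpha_bar = alpha.cumprod(0)
sigma = beta.sqrt()

# The model gets (x_t, t rescaled to [-1, 1]) and predicts the noise
model = nn.Sequential(
    nn.Linear(2, 128), nn.ReLU(),
    nn.Linear(128, 128), nn.ReLU(),
    nn.Linear(128, 1),
)

optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)

nb_epochs = 50
batch_size = 256

for k in range(nb_epochs):
    for x0 in train_input[torch.randperm(train_input.size(0))].split(batch_size):
        # Sample a time step t and a noise eps, build
        # x_t = sqrt(alpha_bar_t) x_0 + sqrt(1 - alpha_bar_t) eps,
        # and regress eps from (x_t, t)
        t = torch.randint(T, (x0.size(0), 1))
        eps = torch.randn(x0.size())
        xt = alpha_bar[t].sqrt() * x0 + (1 - alpha_bar[t]).sqrt() * eps
        input = torch.cat((xt, 2 * t / T - 1), 1)
        loss = (eps - model(input)).pow(2).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
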
    if k%10 == 0: print(k, loss.item())

######################################################################
# Generate

x = torch.randn(10000, 1)
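
# Reverse process: walk t from T-1 down to 0 and apply the denoising
# update below at each step, adding no noise at the final step, as in
# Algorithm 2 of the cited paper. The loop structure and the no_grad
# context are assumptions of this sketch; the update rule itself is
# the one from the excerpt.
with torch.no_grad():
    for t in range(T - 1, -1, -1):
        z = torch.randn(x.size()) if t > 0 else torch.zeros_like(x)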
        input = torch.cat((x, torch.ones(x.size(0), 1) * 2 * t / T - 1), 1)
        x = 1 / alpha[t].sqrt() * (x - (1 - alpha[t])/(1 - alpha_bar[t]).sqrt() * model(input)) + sigma[t] * z

######################################################################
# Plot

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.set_xlim(-1.25, 1.25)
d = train_input.flatten().detach().numpy()
ax.hist(d, 25, (-1, 1),
        density = True,
        histtype = 'stepfilled', color = 'lightblue', label = 'Train')
d = x.flatten().detach().numpy()
ax.hist(d, 25, (-1, 1),
        density = True,
        histtype = 'step', color = 'red', label = 'Synthesis')
ax.legend(frameon = False, loc = 2)
filename = 'diffusion.pdf'
print(f'saving {filename}')
fig.savefig(filename, bbox_inches='tight')
plt.show()