X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=minidiffusion.py;h=6855752e6e67b2e2c53b179617685caf715f24b8;hb=317cc211cf9589a9eee5d937f0d0182719f24790;hp=cbdb1425b090d7f13b33eaa6779e8a9177821d70;hpb=bc937c74ad8cbeede2c2ae97cda72eaa3e9bb4f3;p=pytorch.git diff --git a/minidiffusion.py b/minidiffusion.py index cbdb142..6855752 100755 --- a/minidiffusion.py +++ b/minidiffusion.py @@ -5,6 +5,11 @@ # Written by Francois Fleuret +# Minimal implementation of Jonathan Ho, Ajay Jain, Pieter Abbeel +# "Denoising Diffusion Probabilistic Models" (2020) +# +# https://arxiv.org/abs/2006.11239 + import matplotlib.pyplot as plt import torch from torch import nn @@ -62,26 +67,37 @@ for k in range(nb_epochs): if k%10 == 0: print(k, loss.item()) ###################################################################### -# Plot +# Generate x = torch.randn(10000, 1) for t in range(T-1, -1, -1): z = torch.zeros(x.size()) if t == 0 else torch.randn(x.size()) input = torch.cat((x, torch.ones(x.size(0), 1) * 2 * t / T - 1), 1) - x = 1 / alpha[t].sqrt() * (x - (1 - alpha[t])/(1 - alpha_bar[t]).sqrt() * model(input)) + sigma[t] * z + x = 1 / alpha[t].sqrt() * (x - (1 - alpha[t])/(1 - alpha_bar[t]).sqrt() * model(input)) \ + + sigma[t] * z + +###################################################################### +# Plot fig = plt.figure() ax = fig.add_subplot(1, 1, 1) ax.set_xlim(-1.25, 1.25) d = train_input.flatten().detach().numpy() -ax.hist(d, 25, (-1, 1), histtype = 'stepfilled', color = 'lightblue', density = True, label = 'Train') +ax.hist(d, 25, (-1, 1), + density = True, + histtype = 'stepfilled', color = 'lightblue', label = 'Train') d = x.flatten().detach().numpy() -ax.hist(d, 25, (-1, 1), histtype = 'step', color = 'red', density = True, label = 'Synthesis') +ax.hist(d, 25, (-1, 1), + density = True, + histtype = 'step', color = 'red', label = 'Synthesis') + +ax.legend(frameon = False, loc = 2) filename = 'diffusion.pdf' +print(f'saving {filename}') fig.savefig(filename, bbox_inches='tight') plt.show()