projects
/
pytorch.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
b740a73
)
Update.
author
Francois Fleuret
<francois@fleuret.org>
Fri, 12 Aug 2022 07:57:09 +0000
(09:57 +0200)
committer
Francois Fleuret
<francois@fleuret.org>
Fri, 12 Aug 2022 07:57:09 +0000
(09:57 +0200)
minidiffusion.py
patch
|
blob
|
history
diff --git
a/minidiffusion.py
b/minidiffusion.py
index
ad1cda0
..
037ef11
100755
(executable)
--- a/
minidiffusion.py
+++ b/
minidiffusion.py
@@
-5,6
+5,11
@@
# Written by Francois Fleuret <francois@fleuret.org>
# Written by Francois Fleuret <francois@fleuret.org>
+# Minimal implementation of Jonathan Ho, Ajay Jain, Pieter Abbeel
+# "Denoising Diffusion Probabilistic Models" (2020)
+#
+# https://arxiv.org/abs/2006.11239
+
import matplotlib.pyplot as plt
import torch
from torch import nn
import matplotlib.pyplot as plt
import torch
from torch import nn
@@
-62,7
+67,7
@@
for k in range(nb_epochs):
if k%10 == 0: print(k, loss.item())
######################################################################
if k%10 == 0: print(k, loss.item())
######################################################################
-#
Plot
+#
Generate
x = torch.randn(10000, 1)
x = torch.randn(10000, 1)
@@
-71,19
+76,27
@@
for t in range(T-1, -1, -1):
input = torch.cat((x, torch.ones(x.size(0), 1) * 2 * t / T - 1), 1)
x = 1 / alpha[t].sqrt() * (x - (1 - alpha[t])/(1 - alpha_bar[t]).sqrt() * model(input)) + sigma[t] * z
input = torch.cat((x, torch.ones(x.size(0), 1) * 2 * t / T - 1), 1)
x = 1 / alpha[t].sqrt() * (x - (1 - alpha[t])/(1 - alpha_bar[t]).sqrt() * model(input)) + sigma[t] * z
+######################################################################
+# Plot
+
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.set_xlim(-1.25, 1.25)
d = train_input.flatten().detach().numpy()
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.set_xlim(-1.25, 1.25)
d = train_input.flatten().detach().numpy()
-ax.hist(d, 25, (-1, 1), histtype = 'stepfilled', color = 'lightblue', density = True, label = 'Train')
+ax.hist(d, 25, (-1, 1),
+ density = True,
+ histtype = 'stepfilled', color = 'lightblue', label = 'Train')
d = x.flatten().detach().numpy()
d = x.flatten().detach().numpy()
-ax.hist(d, 25, (-1, 1), histtype = 'step', color = 'red', density = True, label = 'Synthesis')
+ax.hist(d, 25, (-1, 1),
+ density = True,
+ histtype = 'step', color = 'red', label = 'Synthesis')
ax.legend(frameon = False, loc = 2)
filename = 'diffusion.pdf'
ax.legend(frameon = False, loc = 2)
filename = 'diffusion.pdf'
+print(f'saving {filename}')
fig.savefig(filename, bbox_inches='tight')
plt.show()
fig.savefig(filename, bbox_inches='tight')
plt.show()