X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=minidiffusion.py;h=6855752e6e67b2e2c53b179617685caf715f24b8;hb=317cc211cf9589a9eee5d937f0d0182719f24790;hp=037ef11f0b3b5ebb312d35b04f690d686e249706;hpb=f27d6083fbe7243f5896ddd49587fe1923fe9a79;p=pytorch.git diff --git a/minidiffusion.py b/minidiffusion.py index 037ef11..6855752 100755 --- a/minidiffusion.py +++ b/minidiffusion.py @@ -74,7 +74,8 @@ x = torch.randn(10000, 1) for t in range(T-1, -1, -1): z = torch.zeros(x.size()) if t == 0 else torch.randn(x.size()) input = torch.cat((x, torch.ones(x.size(0), 1) * 2 * t / T - 1), 1) - x = 1 / alpha[t].sqrt() * (x - (1 - alpha[t])/(1 - alpha_bar[t]).sqrt() * model(input)) + sigma[t] * z + x = 1 / alpha[t].sqrt() * (x - (1 - alpha[t])/(1 - alpha_bar[t]).sqrt() * model(input)) \ + + sigma[t] * z ###################################################################### # Plot