Update.
[pytorch.git] / minidiffusion.py
index 037ef11..a386a12 100755 (executable)
@@ -48,6 +48,7 @@ alpha_bar = alpha.log().cumsum(0).exp()
 sigma = beta.sqrt()
 
 for k in range(nb_epochs):
+
     acc_loss = 0
     optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4 * (1 - k / nb_epochs) )
 
@@ -74,7 +75,8 @@ x = torch.randn(10000, 1)
 for t in range(T-1, -1, -1):
     z = torch.zeros(x.size()) if t == 0 else torch.randn(x.size())
     input = torch.cat((x, torch.ones(x.size(0), 1) * 2 * t / T - 1), 1)
-    x = 1 / alpha[t].sqrt() * (x - (1 - alpha[t])/(1 - alpha_bar[t]).sqrt() * model(input)) + sigma[t] * z
+    x = 1 / alpha[t].sqrt() * (x - (1 - alpha[t])/(1 - alpha_bar[t]).sqrt() * model(input)) \
+        + sigma[t] * z
 
 ######################################################################
 # Plot