+######################################################################
+# Generate
+
+def generate(size, alpha, alpha_bar, sigma, model):
+ with torch.no_grad():
+ x = torch.randn(size, device = device)
+
+ for t in range(T-1, -1, -1):
+ z = torch.zeros_like(x) if t == 0 else torch.randn_like(x)
+ input = torch.cat((x, torch.full_like(x[:,:1], t / (T - 1) - 0.5)), 1)
+ x = 1/torch.sqrt(alpha[t]) \
+ * (x - (1-alpha[t]) / torch.sqrt(1-alpha_bar[t]) * model(input)) \
+ + sigma[t] * z
+
+ x = x * train_std + train_mean
+
+ return x
+