+print(f'nb_parameters {sum([ p.numel() for p in model.parameters() ])}')
+
+######################################################################
+# Generate
+
+def generate(size, alpha, alpha_bar, sigma, model, train_mean, train_std):
+
+ with torch.no_grad():
+
+ x = torch.randn(size, device = device)
+
+ for t in range(T-1, -1, -1):
+ output = model((x, t / (T - 1) - 0.5))
+ z = torch.zeros_like(x) if t == 0 else torch.randn_like(x)
+ x = 1/torch.sqrt(alpha[t]) \
+ * (x - (1-alpha[t]) / torch.sqrt(1-alpha_bar[t]) * output) \
+ + sigma[t] * z
+
+ x = x * train_std + train_mean
+
+ return x
+