+ x = generate(
+ (128,) + train_input.size()[1:],
+ T,
+ alpha,
+ alpha_bar,
+ sigma,
+ model,
+ train_mean,
+ train_std,
+ )
+
+ x = torchvision.utils.make_grid(
+ x.clamp(min=0, max=255), nrow=16, padding=1, pad_value=64
+ )
+ x = F.pad(x, pad=(2, 2, 2, 2), value=64)[None]
+
+ t = torchvision.utils.make_grid(train_input[:128], nrow=16, padding=1, pad_value=64)
+ t = F.pad(t, pad=(2, 2, 2, 2), value=64)[None]
+
+ result = 1 - torch.cat((t, x), 2) / 255
+
+ filename = f"minidiffusion_{args.data}.png"
+ print(f"saving {filename}")
+ torchvision.utils.save_image(result, filename)