+ filename = f'minidiffusion_{args.data}.pdf'
+ print(f'saving {filename}')
+ fig.savefig(filename, bbox_inches='tight')
+
+ if not args.no_window and hasattr(plt.get_current_fig_manager(), 'window'):
+ plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
+ plt.show()
+
+########################################
+# NxCxHxW -> image
+elif train_input.dim() == 4:
+
+ x = generate((128,) + train_input.size()[1:], T, alpha, alpha_bar, sigma,
+ model, train_mean, train_std)
+
+ x = torchvision.utils.make_grid(x.clamp(min = 0, max = 255),
+ nrow = 16, padding = 1, pad_value = 64)
+ x = F.pad(x, pad = (2, 2, 2, 2), value = 64)[None]
+
+ t = torchvision.utils.make_grid(train_input[:128],
+ nrow = 16, padding = 1, pad_value = 64)
+ t = F.pad(t, pad = (2, 2, 2, 2), value = 64)[None]
+
+ result = 1 - torch.cat((t, x), 2) / 255
+
+ filename = f'minidiffusion_{args.data}.png'
+ print(f'saving {filename}')
+ torchvision.utils.save_image(result, filename)
+
+else: