- x = generate((128,) + train_input.size()[1:], model)
- x = 1 - x.clamp(min = 0, max = 255) / 255
- torchvision.utils.save_image(x, f'diffusion_{args.data}.png', nrow = 16, pad_value = 0.8)
+
+ x = generate((128,) + train_input.size()[1:], T, alpha, alpha_bar, sigma,
+ model, train_mean, train_std)
+
+ x = torchvision.utils.make_grid(x.clamp(min = 0, max = 255),
+ nrow = 16, padding = 1, pad_value = 64)
+ x = F.pad(x, pad = (2, 2, 2, 2), value = 64)[None]
+
+ t = torchvision.utils.make_grid(train_input[:128],
+ nrow = 16, padding = 1, pad_value = 64)
+ t = F.pad(t, pad = (2, 2, 2, 2), value = 64)[None]
+
+ result = 1 - torch.cat((t, x), 2) / 255
+
+ filename = f'minidiffusion_{args.data}.png'
+ print(f'saving {filename}')
+ torchvision.utils.save_image(result, filename)
+
+else:
+
+ print(f'cannot plot result of size {train_input.size()}')