- x = generate((128,) + train_input.size()[1:], model)
- x = 1 - x.clamp(min = 0, max = 255) / 255
- torchvision.utils.save_image(x, f'diffusion_{args.data}.png', nrow = 16, pad_value = 0.8)
+ x = generate(
+ (128,) + train_input.size()[1:],
+ T,
+ alpha,
+ alpha_bar,
+ sigma,
+ model,
+ train_mean,
+ train_std,
+ )
+
+ x = torchvision.utils.make_grid(
+ x.clamp(min=0, max=255), nrow=16, padding=1, pad_value=64
+ )
+ x = F.pad(x, pad=(2, 2, 2, 2), value=64)[None]
+
+ t = torchvision.utils.make_grid(train_input[:128], nrow=16, padding=1, pad_value=64)
+ t = F.pad(t, pad=(2, 2, 2, 2), value=64)[None]
+
+ result = 1 - torch.cat((t, x), 2) / 255
+
+ filename = f"minidiffusion_{args.data}.png"
+ print(f"saving {filename}")
+ torchvision.utils.save_image(result, filename)
+
+else:
+ print(f"cannot plot result of size {train_input.size()}")