def save_image(x, filename):
x = x * train_std + train_mu
x = x.clamp(min=0, max=255) / 255
- torchvision.utils.save_image(1 - x, filename, nrow=16, pad_value=0.8)
+ torchvision.utils.save_image(1 - x, filename, nrow=12, pad_value=1.0)
log_string(f"wrote {filename}")
# Save a bunch of train images
- x = train_input[:256]
+ x = train_input[:36]
save_image(x, f"{prefix}train_input.png")
# Save the same images after encoding / decoding
# Save a bunch of test images
- x = test_input[:256]
+ x = test_input[:36]
save_image(x, f"{prefix}input.png")
# Save the same images after encoding / decoding