def save_image(self, input, result_dir, filename, logger):
img = world.sample2img(input.to("cpu"), self.height, self.width)
image_name = os.path.join(result_dir, filename)
- torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2)
+ torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
logger(f"wrote {image_name}")
def make_ar_mask(self, input):
self.batch_size = batch_size
self.device = device
- self.height = 6
- self.width = 8
+ self.height = 7
+ self.width = 9
self.train_input = world.generate(
nb_train_samples, height=self.height, width=self.width
if result_dir is not None:
self.save_image(
- self.train_input[:96], result_dir, f"world_train.png", logger
+ self.train_input[:72], result_dir, f"world_train.png", logger
)
def batches(self, split="train", desc=None):
)
self.save_image(
- result[:96],
+ result[:72],
result_dir,
f"world_prediction_{n_epoch:04d}_{model.id:02d}.png",
logger,