X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=ad952377ed6a6d1610bc2923985d42b295bc04d5;hb=55aedeed5cb8f0b61a625e64dcaeb0c1fd21d9f6;hp=9a67127bec3295904ec44d0c14cec4babf4a9136;hpb=4ec52fe66419a6e1d2b231108ccbb45902395fcc;p=culture.git diff --git a/tasks.py b/tasks.py index 9a67127..ad95237 100755 --- a/tasks.py +++ b/tasks.py @@ -81,7 +81,7 @@ class World(Task): def save_image(self, input, result_dir, filename, logger): img = world.sample2img(input.to("cpu"), self.height, self.width) image_name = os.path.join(result_dir, filename) - torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2) + torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4) logger(f"wrote {image_name}") def make_ar_mask(self, input): @@ -101,8 +101,8 @@ class World(Task): self.batch_size = batch_size self.device = device - self.height = 6 - self.width = 8 + self.height = 7 + self.width = 9 self.train_input = world.generate( nb_train_samples, height=self.height, width=self.width @@ -126,7 +126,7 @@ class World(Task): if result_dir is not None: self.save_image( - self.train_input[:96], result_dir, f"world_train.png", logger + self.train_input[:72], result_dir, f"world_train.png", logger ) def batches(self, split="train", desc=None): @@ -222,7 +222,7 @@ class World(Task): ) self.save_image( - result[:96], + result[:72], result_dir, f"world_prediction_{n_epoch:04d}_{model.id:02d}.png", logger,