X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=ecf0a652185f5733a1bb175b64db8e61b89b0276;hb=fdce96c3960f5f544bd68c0f18cc5cd096ecbfb3;hp=9a67127bec3295904ec44d0c14cec4babf4a9136;hpb=4ec52fe66419a6e1d2b231108ccbb45902395fcc;p=culture.git diff --git a/tasks.py b/tasks.py index 9a67127..ecf0a65 100755 --- a/tasks.py +++ b/tasks.py @@ -81,7 +81,7 @@ class World(Task): def save_image(self, input, result_dir, filename, logger): img = world.sample2img(input.to("cpu"), self.height, self.width) image_name = os.path.join(result_dir, filename) - torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2) + torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4) logger(f"wrote {image_name}") def make_ar_mask(self, input): @@ -104,11 +104,11 @@ class World(Task): self.height = 6 self.width = 8 - self.train_input = world.generate( + self.train_input = world.generate_seq( nb_train_samples, height=self.height, width=self.width ).to(device) - self.test_input = world.generate( + self.test_input = world.generate_seq( nb_test_samples, height=self.height, width=self.width ).to(device) @@ -126,7 +126,7 @@ class World(Task): if result_dir is not None: self.save_image( - self.train_input[:96], result_dir, f"world_train.png", logger + self.train_input[:72], result_dir, f"world_train.png", logger ) def batches(self, split="train", desc=None): @@ -222,7 +222,7 @@ class World(Task): ) self.save_image( - result[:96], + result[:72], result_dir, f"world_prediction_{n_epoch:04d}_{model.id:02d}.png", logger,