Update.
[culture.git] / tasks.py
index 9a67127..ecf0a65 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -81,7 +81,7 @@ class World(Task):
     def save_image(self, input, result_dir, filename, logger):
         img = world.sample2img(input.to("cpu"), self.height, self.width)
         image_name = os.path.join(result_dir, filename)
-        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2)
+        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
         logger(f"wrote {image_name}")
 
     def make_ar_mask(self, input):
@@ -104,11 +104,11 @@ class World(Task):
         self.height = 6
         self.width = 8
 
-        self.train_input = world.generate(
+        self.train_input = world.generate_seq(
             nb_train_samples, height=self.height, width=self.width
         ).to(device)
 
-        self.test_input = world.generate(
+        self.test_input = world.generate_seq(
             nb_test_samples, height=self.height, width=self.width
         ).to(device)
 
@@ -126,7 +126,7 @@ class World(Task):
 
         if result_dir is not None:
             self.save_image(
-                self.train_input[:96], result_dir, f"world_train.png", logger
+                self.train_input[:72], result_dir, f"world_train.png", logger
             )
 
     def batches(self, split="train", desc=None):
@@ -222,7 +222,7 @@ class World(Task):
         )
 
         self.save_image(
-            result[:96],
+            result[:72],
             result_dir,
             f"world_prediction_{n_epoch:04d}_{model.id:02d}.png",
             logger,