Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 21 Jun 2024 14:45:09 +0000 (16:45 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 21 Jun 2024 14:45:09 +0000 (16:45 +0200)
main.py
tasks.py

diff --git a/main.py b/main.py
index 35f02a3..3acf595 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -474,6 +474,7 @@ elif args.task == "world":
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
         batch_size=args.physical_batch_size,
+        result_dir=args.result_dir,
         logger=log_string,
         device=device,
     )
@@ -902,7 +903,7 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
     # --------------------------------------------
 
     if n_epoch >= 3:
-        nb_required = 1000
+        nb_required = 100
         kept = []
 
         while sum([x.size(0) for x in kept]) < nb_required:
@@ -920,6 +921,13 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
             kept.append(to_keep)
 
         new_problems = torch.cat(kept, dim=0)[:nb_required]
+        task.store_new_problems(new_problems)
+        task.save_image(
+            new_problems[:96],
+            args.result_dir,
+            f"world_new_{n_epoch:04d}.png",
+            log_string,
+        )
 
     # --------------------------------------------
 
index 1b28108..1a6c415 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -2099,11 +2099,18 @@ import world
 
 
 class World(Task):
+    def save_image(self, input, result_dir, filename, logger):
+        img = world.sample2img(self.train_input.to("cpu"), self.height, self.width)
+        image_name = os.path.join(result_dir, filename)
+        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2)
+        logger(f"wrote {image_name}")
+
     def __init__(
         self,
         nb_train_samples,
         nb_test_samples,
         batch_size,
+        result_dir=None,
         logger=None,
         device=torch.device("cpu"),
     ):
@@ -2141,6 +2148,11 @@ class World(Task):
 
         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
 
+        if result_dir is not None:
+            self.save_image(
+                self.train_input[:96], result_dir, f"world_train.png", logger
+            )
+
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
@@ -2200,7 +2212,7 @@ class World(Task):
 
         ##############################
 
-        input, ar_mask = self.test_input[:64], self.test_ar_mask[:64]
+        input, ar_mask = self.test_input[:96], self.test_ar_mask[:96]
         result = input.clone() * (1 - ar_mask)
 
         masked_inplace_autoregression(
@@ -2213,10 +2225,17 @@ class World(Task):
             device=self.device,
         )
 
-        img = world.sample2img(result.to("cpu"), self.height, self.width)
-        image_name = os.path.join(result_dir, f"world_result_{n_epoch:04d}.png")
-        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2)
-        logger(f"wrote {image_name}")
+        self.save_image(result, result_dir, f"world_result_{n_epoch:04d}.png", logger)
+
+    def store_new_problems(self, new_problems):
+        nb_current = self.train_input.size(0)
+        nb_new = new_problems.size(0)
+        if nb_new >= nb_current:
+            self.train_input[...] = new_problems[:nb_current]
+        else:
+            nb_kept = nb_current - nb_new
+            self.train_input[:nb_kept] = self.train_input[-nb_kept:].clone()
+            self.train_input[nb_kept:] = new_problems
 
     def create_new_problems(self, n_epoch, result_dir, logger, nb, model, nb_runs):
         new_problems = torch.empty(
@@ -2234,11 +2253,6 @@ class World(Task):
             device=self.device,
         )
 
-        img = world.sample2img(new_problems[:64].to("cpu"), self.height, self.width)
-        image_name = os.path.join(result_dir, f"world_new_{n_epoch:04d}.png")
-        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2)
-        logger(f"wrote {image_name}")
-
         nb_correct = torch.empty(nb, device=self.device, dtype=torch.int64)
 
         for n in tqdm.tqdm(