Update.
author François Fleuret <francois@fleuret.org>
Fri, 21 Jun 2024 17:45:08 +0000 (19:45 +0200)
committer François Fleuret <francois@fleuret.org>
Fri, 21 Jun 2024 17:45:08 +0000 (19:45 +0200)
main.py
tasks.py

diff --git a/main.py b/main.py
index ca0d152..672dab5 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -906,6 +906,7 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
     # --------------------------------------------
 
     if test_accuracy >= 0.8:
+        nb_runs, nb_min_correct, nb_max_correct = 10, 8, 9
         nb_for_train, nb_for_test = 1000, 100
         kept = []
 
@@ -914,19 +915,23 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
                 n_epoch=n_epoch,
                 result_dir=args.result_dir,
                 logger=log_string,
-                nb=nb_required,
+                nb=4 * (nb_for_train + nb_for_test),
                 model=model,
-                nb_runs=10,
+                nb_runs=nb_runs,
             )
 
-            to_keep = new_quizzes[torch.logical_and(nb_correct >= 8, nb_correct < 10)]
+            to_keep = new_quizzes[
+                torch.logical_and(
+                    nb_correct >= nb_min_correct, nb_correct <= nb_max_correct
+                )
+            ]
             log_string(f"keep {to_keep.size(0)} quizzes")
             kept.append(to_keep)
 
         new_quizzes = torch.cat(kept, dim=0)[: nb_for_train + nb_for_test]
 
-        task.store_new_quizzes(new_quizzes[:nb_for_train], train=True)
-        task.store_new_quizzes(new_quizzes[nb_for_train:], train=False)
+        task.store_new_quizzes(new_quizzes[:nb_for_train], for_train=True)
+        task.store_new_quizzes(new_quizzes[nb_for_train:], for_train=False)
 
         task.save_image(
             new_quizzes[:96],
diff --git a/tasks.py b/tasks.py
index 49b83ec..622cd56 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -2100,7 +2100,7 @@ import world
 
 class World(Task):
     def save_image(self, input, result_dir, filename, logger):
-        img = world.sample2img(self.train_input.to("cpu"), self.height, self.width)
+        img = world.sample2img(input.to("cpu"), self.height, self.width)
         image_name = os.path.join(result_dir, filename)
         torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2)
         logger(f"wrote {image_name}")
@@ -2226,7 +2226,9 @@ class World(Task):
             device=self.device,
         )
 
-        self.save_image(result, result_dir, f"world_result_{n_epoch:04d}.png", logger)
+        self.save_image(
+            result[:96], result_dir, f"world_result_{n_epoch:04d}.png", logger
+        )
 
         return main_test_accuracy
 
@@ -2258,30 +2260,32 @@ class World(Task):
             device=self.device,
         )
 
-        nb_correct = torch.empty(nb, device=self.device, dtype=torch.int64)
+        input = (
+            new_quizzes[:, None, :]
+            .expand(-1, nb_runs, -1)
+            .clone()
+            .reshape(-1, new_quizzes.size(-1))
+        )
+        result = input.clone()
 
-        for n in tqdm.tqdm(
-            range(new_quizzes.size(0)), dynamic_ncols=True, desc="checking quizzes"
-        ):
-            result = new_quizzes[n][None, :].expand(nb_runs, -1).clone()
-            ar_mask = (
-                (torch.arange(result.size(1), device=self.device) > result.size(1) // 2)
-                .long()[None, :]
-                .expand_as(result)
-            )
+        ar_mask = (
+            (torch.arange(result.size(1), device=self.device) > result.size(1) // 2)
+            .long()[None, :]
+            .expand_as(result)
+        )
 
-            masked_inplace_autoregression(
-                model,
-                self.batch_size,
-                result,
-                ar_mask,
-                deterministic_synthesis=False,
-                progress_bar_desc=None,
-                device=self.device,
-            )
+        masked_inplace_autoregression(
+            model,
+            self.batch_size,
+            result,
+            ar_mask,
+            deterministic_synthesis=False,
+            progress_bar_desc=None,
+            device=self.device,
+        )
 
-            nb_correct[n] = (
-                (new_quizzes[n][None, :] == result).long().min(dim=1).values.sum()
-            )
+        nb_correct = (
+            (input == result).long().min(dim=-1).values.reshape(-1, nb_runs).sum(dim=-1)
+        )
 
         return new_quizzes, nb_correct