Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 21 Jun 2024 19:27:44 +0000 (21:27 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 21 Jun 2024 19:27:44 +0000 (21:27 +0200)
main.py
tasks.py
world.py

diff --git a/main.py b/main.py
index 22edf7b..4a1207d 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -845,6 +845,7 @@ def run_tests(model, task, deterministic_synthesis):
 
 
 def create_quizzes(
+    model,
     other_models,
     task,
     nb_for_train=1000,
@@ -861,7 +862,8 @@ def create_quizzes(
             result_dir=args.result_dir,
             logger=log_string,
             nb=4 * (nb_for_train + nb_for_test),
-            models=other_models,
+            model=model,
+            other_models=other_models,
             nb_runs=nb_runs,
         )
 
index 622cd56..0345bd0 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -2244,7 +2244,9 @@ class World(Task):
             input[:nb_kept] = input[-nb_kept:].clone()
             input[nb_kept:] = new_quizzes
 
-    def create_new_quizzes(self, n_epoch, result_dir, logger, nb, model, nb_runs):
+    def create_new_quizzes(
+        self, n_epoch, result_dir, logger, nb, models, other_models, nb_runs
+    ):
         new_quizzes = torch.empty(
             nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
         )
@@ -2274,15 +2276,18 @@ class World(Task):
             .expand_as(result)
         )
 
-        masked_inplace_autoregression(
-            model,
-            self.batch_size,
-            result,
-            ar_mask,
-            deterministic_synthesis=False,
-            progress_bar_desc=None,
-            device=self.device,
-        )
+        dispatch = torch.randint(len(other_models), (result.size(0),))
+
+        for n, m in enumerate(other_models):
+            masked_inplace_autoregression(
+                m,
+                self.batch_size,
+                result[dispatch == n],
+                ar_mask[dispatch == n],
+                deterministic_synthesis=False,
+                progress_bar_desc=None,
+                device=self.device,
+            )
 
         nb_correct = (
             (input == result).long().min(dim=-1).values.reshape(-1, nb_runs).sum(dim=-1)
index 97c7b1d..43126d5 100755 (executable)
--- a/world.py
+++ b/world.py
@@ -34,7 +34,7 @@ def generate(
     nb,
     height,
     width,
-    max_nb_obj=colors.size(0) - 2,
+    max_nb_obj=2,
     nb_iterations=2,
 ):
     f_start = torch.zeros(nb, height, width, dtype=torch.int64)