Update.
[culture.git] / tasks.py
index 622cd56..0345bd0 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -2244,7 +2244,9 @@ class World(Task):
             input[:nb_kept] = input[-nb_kept:].clone()
             input[nb_kept:] = new_quizzes
 
             input[:nb_kept] = input[-nb_kept:].clone()
             input[nb_kept:] = new_quizzes
 
-    def create_new_quizzes(self, n_epoch, result_dir, logger, nb, model, nb_runs):
+    def create_new_quizzes(
+        self, n_epoch, result_dir, logger, nb, models, other_models, nb_runs
+    ):
         new_quizzes = torch.empty(
             nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
         )
         new_quizzes = torch.empty(
             nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
         )
@@ -2274,15 +2276,18 @@ class World(Task):
             .expand_as(result)
         )
 
             .expand_as(result)
         )
 
-        masked_inplace_autoregression(
-            model,
-            self.batch_size,
-            result,
-            ar_mask,
-            deterministic_synthesis=False,
-            progress_bar_desc=None,
-            device=self.device,
-        )
+        dispatch = torch.randint(len(other_models), (result.size(0),))
+
+        for n, m in enumerate(other_models):
+            masked_inplace_autoregression(
+                m,
+                self.batch_size,
+                result[dispatch == n],
+                ar_mask[dispatch == n],
+                deterministic_synthesis=False,
+                progress_bar_desc=None,
+                device=self.device,
+            )
 
         nb_correct = (
             (input == result).long().min(dim=-1).values.reshape(-1, nb_runs).sum(dim=-1)
 
         nb_correct = (
             (input == result).long().min(dim=-1).values.reshape(-1, nb_runs).sum(dim=-1)