Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 22 Jun 2024 14:09:51 +0000 (16:09 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 22 Jun 2024 14:09:51 +0000 (16:09 +0200)
tasks.py

index f6d34a8..cb5900b 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -238,27 +238,46 @@ class World(Task):
         model,
         other_models,
     ):
-        new_quizzes = torch.empty(
+        ###############################################################
+        # Generate quizzes with model
+
+        quizzes = torch.empty(
             nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
         )
-        ar_mask = torch.full(new_quizzes.size(), 1, device=self.device)
+        ar_mask = torch.full(quizzes.size(), 1, device=self.device)
 
         masked_inplace_autoregression(
             model,
             self.batch_size,
-            new_quizzes,
+            quizzes,
             ar_mask,
             deterministic_synthesis=False,
             progress_bar_desc="creating quizzes",
             device=self.device,
         )
 
-        ar_mask = self.make_ar_mask(new_quizzes)
+        ###############################################################
+        # Create the reverse quizzes
+
+        l = self.height * self.width
+        direction = quizzes[:, l : l + 1]
+        direction = world.token_forward * (
+            direction == world.token_backward
+        ) + world.token_backward * (direction == world.token_forward)
+        reverse_quizzes = torch.cat(
+            [quizzes[:, l + 1 :], direction, quizzes[:, :l]], dim=1
+        )
+
+        ar_mask = self.make_ar_mask(quizzes)
+
+        ###############################################################
+        # Check how many of the other models can solve them in both
+        # directions
 
         nb_correct = 0
 
         for m in other_models:
-            result = new_quizzes.clone()
+            result = quizzes.clone()
 
             masked_inplace_autoregression(
                 m,
@@ -270,29 +289,24 @@ class World(Task):
                 device=self.device,
             )
 
-            l = self.height * self.width
-            direction = new_quizzes[:, l : l + 1]
-            direction = world.token_forward * (
-                direction == world.token_backward
-            ) + world.token_backward * (direction == world.token_forward)
-            inverted_quizzes = torch.cat(
-                [new_quizzes[:, l + 1 :], direction, new_quizzes[:, :l]], dim=1
-            )
+            correct = (quizzes == result).long().min(dim=-1).values
 
-            inverted_result = inverted_quizzes.clone()
+            reverse_result = reverse_quizzes.clone()
 
             masked_inplace_autoregression(
                 m,
                 self.batch_size,
-                inverted_result,
+                reverse_result,
                 ar_mask,
                 deterministic_synthesis=True,
                 progress_bar_desc="solving reversed quizzes",
                 device=self.device,
             )
 
-            nb_correct += (new_quizzes == result).long().min(dim=-1).values * (
-                inverted_quizzes == inverted_result
-            ).long().min(dim=-1).values
+            reverse_correct = (
+                (reverse_quizzes == reverse_result).long().min(dim=-1).values
+            )
+
+            nb_correct += correct * reverse_correct
 
-        return new_quizzes, nb_correct
+        return quizzes, nb_correct