Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 22 Jun 2024 09:04:52 +0000 (11:04 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 22 Jun 2024 09:04:52 +0000 (11:04 +0200)
main.py
tasks.py

diff --git a/main.py b/main.py
index 97c7130..b57c512 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -73,6 +73,8 @@ parser.add_argument("--dropout", type=float, default=0.1)
 
 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
 
+parser.add_argument("--check", action="store_true", default=False)
+
 ######################################################################
 
 args = parser.parse_args()
@@ -183,6 +185,9 @@ for n in vars(args):
 
 ######################################################################
 
+if args.test:
+    args.nb_train_samples = 1000
+    args.nb_test_samples = 25
 
 if args.physical_batch_size is None:
     args.physical_batch_size = args.batch_size
@@ -606,6 +611,13 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 ######################################################################
 
 accuracy_to_make_quizzes = 0.975
+nb_new_quizzes_for_train = 1000
+nb_new_quizzes_for_test = 100
+
+if args.test:
+    accuracy_to_make_quizzes = 0.0
+    nb_new_quizzes_for_train = 10
+    nb_new_quizzes_for_test = 10
 
 for n_epoch in range(args.nb_epochs):
     # select the model with lowest accuracy
@@ -634,8 +646,8 @@ for n_epoch in range(args.nb_epochs):
             model,
             other_models,
             task,
-            nb_for_train=1000,
-            nb_for_test=100,
+            nb_for_train=nb_new_quizzes_for_train,
+            nb_for_test=nb_new_quizzes_for_test,
         )
 
 
index 77493a8..50d541b 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -273,6 +273,34 @@ class World(Task):
                 device=self.device,
             )
 
-            nb_correct += (new_quizzes == result).long().min(dim=-1).values
+            l = self.height * self.width
+            direction = new_quizzes[:, l : l + 1]
+            direction = world.token_forward * (
+                direction == world.token_backward
+            ) + world.token_backward * (direction == world.token_forward)
+            inverted_quizzes = torch.cat(
+                [new_quizzes[:, l + 1 :], direction, new_quizzes[:, :l]], dim=1
+            )
+
+            inverted_result = inverted_quizzes.clone()
+
+            masked_inplace_autoregression(
+                m,
+                self.batch_size,
+                inverted_result,
+                ar_mask,
+                deterministic_synthesis=True,
+                progress_bar_desc="solving reverse quizzes",
+                device=self.device,
+            )
+
+            nb_correct += (
+                (
+                    (new_quizzes == result).long()
+                    * (inverted_quizzes, inverted_result).long()
+                )
+                .min(dim=-1)
+                .values
+            )
 
         return new_quizzes, nb_correct