Update.
author     François Fleuret <francois@fleuret.org>
           Fri, 19 Jul 2024 14:21:24 +0000 (16:21 +0200)
committer  François Fleuret <francois@fleuret.org>
           Fri, 19 Jul 2024 14:21:24 +0000 (16:21 +0200)
grids.py
main.py
quiz_machine.py

diff --git a/grids.py b/grids.py
index e1eff00..4db12db 100755
--- a/grids.py
+++ b/grids.py
@@ -118,7 +118,7 @@ class Grids(problem.Problem):
         ("gray", [128, 128, 128]),
     ]
 
-    def make_ar_mask(self, quizzes, first=False):
+    def make_ar_mask(self, quizzes, shape="fwd_3_bck_123"):
         S = self.height * self.width
 
         assert (
@@ -133,12 +133,17 @@ class Grids(problem.Problem):
 
         T = torch.arange(quizzes.size(1), device=quizzes.device)
 
-        if first:
+        if shape == "fwd_3_bck_123":
+            forward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long()
+            backward_mask = ((T % (S + 1) != 0) & (T >= S + 1)).long()
+        elif shape == "fwd_012_bck_0":
             forward_mask = ((T % (S + 1) != 0) & (T < 3 * (S + 1))).long()
             backward_mask = ((T % (S + 1) != 0) & (T < S + 1)).long()
-        else:
+        elif shape == "fwd_3_bck_3":
             forward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long()
-            backward_mask = ((T % (S + 1) != 0) & (T >= S + 1)).long()
+            backward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long()
+        else:
+            raise ValueError(shape)
 
         is_forward = (quizzes[:, 0] == self.token_forward).long()
 
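For reference, here is a minimal, self-contained sketch (not the repository's code) of what the three shape strings select. The names toy_ar_mask and the toy S are illustrative; the assumed layout is four (S + 1)-token blocks whose first token, at positions k * (S + 1), is a structure marker. The real make_ar_mask then picks the forward or backward mask per sequence from its direction token.

    import torch

    def toy_ar_mask(seq_len, S, shape):
        T = torch.arange(seq_len)
        not_marker = T % (S + 1) != 0                # skip the structure token of each block
        if shape == "fwd_3_bck_123":
            fwd = not_marker & (T >= 3 * (S + 1))    # forward: predict block 3 only
            bck = not_marker & (T >= S + 1)          # backward: predict blocks 1, 2 and 3
        elif shape == "fwd_012_bck_0":
            fwd = not_marker & (T < 3 * (S + 1))     # forward: predict blocks 0, 1 and 2
            bck = not_marker & (T < S + 1)           # backward: predict block 0 only
        elif shape == "fwd_3_bck_3":
            fwd = not_marker & (T >= 3 * (S + 1))    # both directions: block 3 only
            bck = not_marker & (T >= 3 * (S + 1))
        else:
            raise ValueError(shape)
        return fwd.long(), bck.long()

    S = 4  # toy 2x2 grid
    fwd, bck = toy_ar_mask(4 * (S + 1), S, "fwd_3_bck_123")
    print(fwd.view(4, S + 1))  # only the last block, minus its marker, is masked
    print(bck.view(4, S + 1))  # blocks 1..3, minus their markers, are masked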
@@ -147,7 +152,7 @@ class Grids(problem.Problem):
             + (1 - is_forward)[:, None] * backward_mask[None, :]
         )
 
-    def p_a_flip(self, quizzes):
+    def p_a_flip(self, quizzes, pairwise_flip=False):
         S = self.height * self.width
 
         assert (
@@ -160,10 +165,26 @@ class Grids(problem.Problem):
             & (quizzes[:, 0] == quizzes[:, 3 * (S + 1)])
         ).all()
 
-        flipped = torch.cat(
-            [quizzes[:, k * (S + 1) : (k + 1) * (S + 1)] for k in range(3, -1, -1)],
-            dim=1,
-        )
+        if pairwise_flip:
+            flipped = torch.cat(
+                [
+                    quizzes[:, 1 * (S + 1) : 1 * (S + 1) + S + 1],
+                    quizzes[:, 0 * (S + 1) : 0 * (S + 1) + S + 1],
+                    quizzes[:, 3 * (S + 1) : 3 * (S + 1) + S + 1],
+                    quizzes[:, 2 * (S + 1) : 2 * (S + 1) + S + 1],
+                ],
+                dim=1,
+            )
+        else:
+            flipped = torch.cat(
+                [
+                    quizzes[:, 3 * (S + 1) : 3 * (S + 1) + S + 1],
+                    quizzes[:, 2 * (S + 1) : 2 * (S + 1) + S + 1],
+                    quizzes[:, 1 * (S + 1) : 1 * (S + 1) + S + 1],
+                    quizzes[:, 0 * (S + 1) : 0 * (S + 1) + S + 1],
+                ],
+                dim=1,
+            )
 
         m = (flipped[:, 0] == self.token_forward).long()
         flipped[:, 0 * (S + 1)] = m * self.token_backward + (1 - m) * self.token_forward
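The two branches of p_a_flip reorder the four (S + 1)-token blocks A, f(A), B, f(B). A toy standalone illustration of the two orderings (names are illustrative, and the direction-token rewrite done at the end of the real method is omitted):

    import torch

    S = 3
    # One sequence whose four blocks are filled with 0, 1, 2, 3 so the reordering is visible.
    blocks = torch.arange(4).repeat_interleave(S + 1).unsqueeze(0)

    def reorder(quizzes, order):
        # Concatenate the (S + 1)-token blocks in the given order along the time axis.
        return torch.cat([quizzes[:, k * (S + 1) : (k + 1) * (S + 1)] for k in order], dim=1)

    print(reorder(blocks, [3, 2, 1, 0]))  # pairwise_flip=False: f(B), B, f(A), A
    print(reorder(blocks, [1, 0, 3, 2]))  # pairwise_flip=True:  f(A), A, f(B), B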
diff --git a/main.py b/main.py
index 0182e6a..562a95d 100755
--- a/main.py
+++ b/main.py
@@ -425,10 +425,13 @@ def keep_good_quizzes(models, quizzes):
 
     elif args.c_quiz_validation_mode == "predict":
         nc = quiz_machine.solution_nb_correct(models, quizzes)
+
         count_nc = tuple(
             n.item() for n in F.one_hot(nc, num_classes=len(models) + 1).sum(dim=0)
         )
+
         log_string(f"nb_correct {count_nc}")
+
         to_keep = nc == (len(models) - 1)
 
     else:
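The counting added above can be reproduced in isolation: nc[i] is the number of models that solved quiz i, and the one-hot sum gives, for each n, how many quizzes were solved by exactly n models. A hypothetical standalone example with made-up values:

    import torch
    import torch.nn.functional as F

    nb_models = 4
    nc = torch.tensor([0, 4, 3, 3, 4, 1])  # per-quiz number of models that got it right

    count_nc = tuple(
        n.item() for n in F.one_hot(nc, num_classes=nb_models + 1).sum(dim=0)
    )
    print(count_nc)                  # (1, 1, 0, 2, 2): e.g. two quizzes solved by exactly 3 models

    to_keep = nc == (nb_models - 1)  # keep quizzes solved by all models but one
    print(to_keep)                   # tensor([False, False,  True,  True, False, False])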
diff --git a/quiz_machine.py b/quiz_machine.py
index 046ab73..91eb3ac 100755
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -293,7 +293,7 @@ class QuizMachine:
     def produce_results(self, n_epoch, model, result_dir, deterministic_synthesis):
         def compute_accuracy(input, log_prefix=None):
             input = input.to(self.device)
-            ar_mask = self.problem.make_ar_mask(input)
+            ar_mask = self.problem.make_ar_mask(input, shape="fwd_3_bck_123")
             result = input.clone() * (1 - ar_mask)
             seq_logproba = torch.empty(input.size(0), device=self.device)
 
@@ -432,7 +432,7 @@ class QuizMachine:
                     c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
                 ):
                     input = input.to(self.device)
-                    ar_mask = self.problem.make_ar_mask(input)
+                    ar_mask = self.problem.make_ar_mask(input, shape="fwd_3_bck_123")
                     output = model(mygpt.BracketedSequence(input)).x
                     l[:, model.id] = (
                         -F.cross_entropy(
@@ -448,10 +448,7 @@ class QuizMachine:
     ###############################################################
 
     def solution_nb_correct(
-        self,
-        models_for_validation,
-        c_quizzes,
-        deterministic_validation=False,
+        self, models_for_validation, c_quizzes, bidirectional_validation=True
     ):
         seq_logproba = torch.zeros(
             c_quizzes.size(0),
@@ -464,10 +461,11 @@ class QuizMachine:
         seq_logproba[...] = 0.0
 
         for model in models_for_validation:
+            # A, f(A), B | f(B)
             c_quizzes = c_quizzes.to(self.device)
             result = c_quizzes.clone()
 
-            ar_mask = self.problem.make_ar_mask(result)
+            ar_mask = self.problem.make_ar_mask(result, shape="fwd_3_bck_3")
 
             masked_inplace_autoregression(
                 model=model,
@@ -476,13 +474,38 @@ class QuizMachine:
                 ar_mask=ar_mask,
                 seq_logproba=seq_logproba[:, model.id],
                 temperature=1.0,
-                deterministic_synthesis=deterministic_validation,
+                deterministic_synthesis=False,
                 device=self.device,
             )
 
             correct = (c_quizzes == result).long().min(dim=-1).values
 
-            nb_correct += correct
+            # -------------------------------
+
+            # f(A), A, f(B) | B
+            c_quizzes = self.problem.p_a_flip(c_quizzes, pairwise_flip=True).to(
+                self.device
+            )
+            result = c_quizzes.clone()
+
+            ar_mask = self.problem.make_ar_mask(result, shape="fwd_3_bck_3")
+
+            masked_inplace_autoregression(
+                model=model,
+                batch_size=self.batch_size,
+                input=result,
+                ar_mask=ar_mask,
+                seq_logproba=seq_logproba[:, model.id],
+                temperature=1.0,
+                deterministic_synthesis=False,
+                device=self.device,
+            )
+
+            flipped_correct = (c_quizzes == result).long().min(dim=-1).values
+
+            # -------------------------------
+
+            nb_correct += correct * flipped_correct
 
         return nb_correct.to("cpu")
 
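The aggregation at the end of solution_nb_correct now requires a quiz to be regenerated correctly in both orderings, A, f(A), B | f(B) and f(A), A, f(B) | B. A small sketch of that logic with made-up 0/1 indicators (illustrative only):

    import torch

    correct         = torch.tensor([1, 1, 0, 1])  # first pass:  A, f(A), B | f(B)
    flipped_correct = torch.tensor([1, 0, 0, 1])  # second pass: f(A), A, f(B) | B

    nb_correct = torch.zeros(4, dtype=torch.long)
    nb_correct += correct * flipped_correct       # elementwise AND: both passes must succeed
    print(nb_correct)                             # tensor([1, 0, 0, 1])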
@@ -512,7 +535,7 @@ class QuizMachine:
                 model=model_for_generation,
                 batch_size=self.batch_size,
                 input=c_quizzes,
-                ar_mask=self.problem.make_ar_mask(c_quizzes, first=True),
+                ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"),
                 seq_logproba=seq_logproba,
                 temperature=temperature_hot,
                 deterministic_synthesis=False,
@@ -523,7 +546,7 @@ class QuizMachine:
                 model=model_for_generation,
                 batch_size=self.batch_size,
                 input=c_quizzes,
-                ar_mask=self.problem.make_ar_mask(c_quizzes),
+                ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
                 seq_logproba=seq_logproba,
                 temperature=temperature_cold,
                 deterministic_synthesis=False,
@@ -537,7 +560,7 @@ class QuizMachine:
                 model=model_for_generation,
                 batch_size=self.batch_size,
                 input=c_quizzes,
-                ar_mask=self.problem.make_ar_mask(c_quizzes, first=True),
+                ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"),
                 seq_logproba=seq_logproba,
                 temperature=temperature_hot,
                 deterministic_synthesis=False,
@@ -548,7 +571,7 @@ class QuizMachine:
                 model=model_for_generation,
                 batch_size=self.batch_size,
                 input=c_quizzes,
-                ar_mask=self.problem.make_ar_mask(c_quizzes),
+                ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
                 seq_logproba=seq_logproba,
                 temperature=temperature_cold,
                 deterministic_synthesis=False,
@@ -561,7 +584,7 @@ class QuizMachine:
                 model=model_for_generation,
                 batch_size=self.batch_size,
                 input=c_quizzes,
-                ar_mask=self.problem.make_ar_mask(c_quizzes),
+                ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
                 seq_logproba=seq_logproba,
                 temperature=temperature_cold,
                 deterministic_synthesis=False,
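The generation code above follows a two-temperature pattern: a hot pass fills the mask "fwd_012_bck_0" (inventing A, f(A) and B), then a cold pass fills "fwd_3_bck_123" (completing the answer). Below is a schematic, runnable sketch of that control flow, with a dummy fill standing in for masked_inplace_autoregression and using only the forward-direction masks; all names are illustrative, not the repository's API.

    import torch

    S = 4
    L = 4 * (S + 1)

    def toy_mask(shape):
        T = torch.arange(L)
        not_marker = T % (S + 1) != 0
        if shape == "fwd_012_bck_0":
            return (not_marker & (T < 3 * (S + 1))).long()   # blocks 0, 1, 2
        if shape == "fwd_3_bck_123":
            return (not_marker & (T >= 3 * (S + 1))).long()  # block 3 (forward variant)
        raise ValueError(shape)

    def fill(x, mask, temperature):
        # Stand-in for the model's masked autoregressive sampling: random tokens,
        # drawn from a wider range when the temperature is higher.
        noise = torch.randint(0, max(2, int(10 * temperature)), x.size())
        return x * (1 - mask) + noise * mask

    c_quizzes = torch.zeros(2, L, dtype=torch.long)
    c_quizzes = fill(c_quizzes, toy_mask("fwd_012_bck_0"), temperature=1.5)  # hot: invent the quiz
    c_quizzes = fill(c_quizzes, toy_mask("fwd_3_bck_123"), temperature=1.0)  # cold: complete it
    print(c_quizzes)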