Update.
author     François Fleuret <francois@fleuret.org>  Wed, 24 Jul 2024 16:22:50 +0000 (18:22 +0200)
committer  François Fleuret <francois@fleuret.org>  Wed, 24 Jul 2024 16:22:50 +0000 (18:22 +0200)
grids.py
main.py
quiz_machine.py

diff --git a/grids.py b/grids.py
index 37ed6a0..5ddcf32 100755
--- a/grids.py
+++ b/grids.py
@@ -131,7 +131,8 @@ class Grids(problem.Problem):
     def get_structure(self, quizzes):
         S = self.height * self.width
         struct = tuple(
-            self.tok2l[n.item()] for n in quizzes.reshape(-1, 4, S + 1)[0, :, 0]
+            self.tok2l[n.item()]
+            for n in quizzes.reshape(quizzes.size(0), 4, S + 1)[0, :, 0]
         )
         self.check_structure(quizzes, struct)
         return struct
@@ -143,8 +144,8 @@ class Grids(problem.Problem):
         sf = dict((l, n) for n, l in enumerate(struct_from))
 
         result = quizzes.new(quizzes.size())
-        q = quizzes.reshape(-1, 4, S + 1)
-        r = result.reshape(-1, 4, S + 1)
+        q = quizzes.reshape(quizzes.size(0), 4, S + 1)
+        r = result.reshape(result.size(0), 4, S + 1)
 
         r[:, 0] = q[:, sf[struct[0]], :]
         r[:, 1] = q[:, sf[struct[1]], :]
@@ -153,12 +154,21 @@ class Grids(problem.Problem):
 
         return result
 
+    def non_trivial(self, quizzes):
+        S = self.height * self.width
+        assert self.check_structure(quizzes, struct=("A", "f_A", "B", "f_B"))
+        a = quizzes.reshape(quizzes.size(0), 4, S + 1)[:, :, 1:]
+        return (a[:, 0] != a[:, 1]).max(dim=1).values | (a[:, 2] != a[:, 3]).max(
+            dim=1
+        ).values
+
     def make_ar_mask(self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)):
-        assert check_structure(quizzes, struct)
+        assert self.check_structure(quizzes, struct)
 
         ar_mask = quizzes.new_zeros(quizzes.size())
 
-        a = ar_mask.reshape(-1, 4, -1)[:, :, 1:]
+        S = self.height * self.width
+        a = ar_mask.reshape(ar_mask.size(0), 4, S + 1)[:, :, 1:]
         a[:, 0, :] = mask[0]
         a[:, 1, :] = mask[1]
         a[:, 2, :] = mask[2]
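
Note: the previous ar_mask.reshape(-1, 4, -1) could not have worked, since PyTorch allows at most one inferred dimension per reshape; the fix spells out both the batch size and S + 1. non_trivial above keeps a quiz when the transformation changes A or B somewhere, which is what lets the caller in main.py discard identity quizzes. A tiny sketch of the mask produced for mask=(0, 0, 0, 1), with a small S for readability:

    # The mask marks only the S grid tokens of the last part for generation;
    # the four structure tokens (one per part) stay at 0.
    import torch

    S = 3
    ar_mask = torch.zeros(1, 4 * (S + 1), dtype=torch.int64)
    a = ar_mask.reshape(ar_mask.size(0), 4, S + 1)[:, :, 1:]
    for k, m in enumerate((0, 0, 0, 1)):
        a[:, k, :] = m
    print(ar_mask)
    # tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1]])
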
@@ -168,7 +178,7 @@ class Grids(problem.Problem):
 
     def indices_select(self, quizzes, struct=("A", "f_A", "B", "f_B")):
         S = self.height * self.width
-        q = quizzes.reshape(-1, 4, S + 1)
+        q = quizzes.reshape(quizzes.size(0), 4, S + 1)
         return (
             (q[:, 0, 0] == self.l2tok[struct[0]])
             & (q[:, 1, 0] == self.l2tok[struct[1]])
@@ -286,11 +296,13 @@ class Grids(problem.Problem):
         nrow=4,
         margin=8,
     ):
+        quizzes = quizzes.to("cpu")
+
         S = self.height * self.width
 
         A, f_A, B, f_B = (
-            quizzes.reshape(-1, 4, S + 1)[:, :, 1:]
-            .reshape(-1, 4, self.height, self.width)
+            quizzes.reshape(quizzes.size(0), 4, S + 1)[:, :, 1:]
+            .reshape(quizzes.size(0), 4, self.height, self.width)
             .permute(1, 0, 2, 3)
         )
 
@@ -1382,14 +1394,16 @@ class Grids(problem.Problem):
     def create_empty_quizzes(self, nb, struct=("A", "f_A", "B", "f_B")):
         S = self.height * self.width
         quizzes = torch.zeros(nb, 4 * (S + 1), dtype=torch.int64)
-        quizzes[:, 0 * (S + 1)] = self.l2tok(struct[0])
-        quizzes[:, 1 * (S + 1)] = self.l2tok(struct[1])
-        quizzes[:, 2 * (S + 1)] = self.l2tok(struct[2])
-        quizzes[:, 3 * (S + 1)] = self.l2tok(struct[3])
+        quizzes[:, 0 * (S + 1)] = self.l2tok[struct[0]]
+        quizzes[:, 1 * (S + 1)] = self.l2tok[struct[1]]
+        quizzes[:, 2 * (S + 1)] = self.l2tok[struct[2]]
+        quizzes[:, 3 * (S + 1)] = self.l2tok[struct[3]]
 
         return quizzes
 
     def generate_w_quizzes_(self, nb, tasks=None, progress_bar=False):
+        S = self.height * self.width
+
         if tasks is None:
             tasks = self.all_tasks
 
diff --git a/main.py b/main.py
index fcca116..deba848 100755
--- a/main.py
+++ b/main.py
@@ -324,7 +324,7 @@ log_string(f"vocabulary_size {vocabulary_size}")
 ######################################################################
 
 
-def run_tests(model, quiz_machine, deterministic_synthesis, local_device=main_device):
+def run_tests(model, quiz_machine, local_device=main_device):
     with torch.autograd.no_grad():
         model.eval().to(local_device)
 
@@ -355,7 +355,6 @@ def run_tests(model, quiz_machine, deterministic_synthesis, local_device=main_de
             model=model,
             input=full_input[:2000],
             result_dir=args.result_dir,
-            deterministic_synthesis=deterministic_synthesis,
         )
 
 
@@ -408,7 +407,7 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
     log_string(f"train_perplexity {n_epoch} model {model.id} {train_perplexity}")
 
-    run_tests(model, quiz_machine, deterministic_synthesis=False)
+    run_tests(model, quiz_machine)
 
     threshold = torch.cat([l for _, l in hard_w_quizzes], dim=0).sort().values
     threshold = threshold[threshold.size(0) // 2]
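
Note: the two threshold lines in the surrounding context pick the middle element of the sorted per-quiz losses, i.e. the (upper) median loss of the hard world quizzes. A quick check of that median trick:

    import torch

    losses = torch.tensor([0.3, 2.0, 0.1, 0.7, 1.5])
    threshold = losses.sort().values
    threshold = threshold[threshold.size(0) // 2]
    print(threshold)                    # tensor(0.7000) -- the median
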
@@ -455,7 +454,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
         # We discard the trivial ones, according to a criterion
         # specific to the world quizzes (e.g. B=f(B))
 
-        c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
+        c_quizzes = c_quizzes[quiz_machine.problem.non_trivial(c_quizzes)]
 
         # We go through nb_rounds rounds and keep only quizzes on
         # which
@@ -471,6 +470,9 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
         remains = [c_quizzes.size(0)]
 
         for r in range(args.nb_rounds):
+            if c_quizzes.size(0) == 0:
+                break
+
             number_correct_responses += quiz_machine.models_successes(models, c_quizzes)
 
             nb_sure_correct = (number_correct_responses == r + 1).long().sum(dim=1)
@@ -487,9 +489,6 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 
             remains.append(c_quizzes.size(0))
 
-            if c_quizzes.size(0) == 0:
-                break
-
         if c_quizzes.size(0) > 0:
             nb_validated_per_model[model_for_generation.id] += c_quizzes.size(0)
             recorded_validated.append(c_quizzes)
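
Note: moving the empty-batch check to the top of the round loop means models_successes is never called on zero candidate quizzes; together with the guard added in quiz_machine.py below, an empty selection now just falls through. A toy version of the pattern, with made-up names:

    # Filter a candidate batch over several rounds, bailing out once it is empty.
    import torch

    candidates = torch.randn(8, 4)
    for r in range(10):
        if candidates.size(0) == 0:
            break                       # nothing left: skip the model pass entirely
        scores = candidates.sum(dim=1)  # stand-in for models_successes(...)
        candidates = candidates[scores > 0.5]
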
@@ -550,9 +549,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
                 v = " ".join([str(n.item()) for n in r])
                 f.write(f"{n}: {v}\n")
 
-        quiz_machine.save_quiz_illustrations(
-            args.result_dir, prefix, vq, show_part_to_predict=False
-        )
+        quiz_machine.save_quizzes_as_image(args.result_dir, prefix, vq)
 
 
 ######################################################################
@@ -727,8 +724,10 @@ for n_epoch in range(current_epoch, args.nb_epochs):
             temperature_cold=args.temperature_cold,
         )
 
-        quiz_machine.save_quiz_illustrations(
-            args.result_dir, f"non_validated_{n_epoch:04d}_{model.id:02d}", c_quizzes
+        quiz_machine.problem.save_quizzes_as_image(
+            args.result_dir,
+            f"non_validated_{n_epoch:04d}_{model.id:02d}.png",
+            c_quizzes,
         )
 
     # Renew the training samples
diff --git a/quiz_machine.py b/quiz_machine.py
index bb62181..2fb196c 100755
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -32,6 +32,9 @@ def one_batch_masked_inplace_autoregression(
     logit_transformer=None,
     deterministic_synthesis=False,
 ):
+    if input.size(0) == 0:
+        return
+
     to_generate = (ar_mask.sum(0) > 0).nonzero()
 
     if to_generate.min() > 0:
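
Note: the early return above matters because the next statements reduce over the selection: with an empty input, (ar_mask.sum(0) > 0).nonzero() is empty and calling .min() on an empty tensor raises a RuntimeError. A minimal reproduction, assuming nothing beyond stock PyTorch:

    import torch

    input = torch.zeros(0, 16, dtype=torch.int64)   # empty batch
    ar_mask = torch.zeros_like(input)
    to_generate = (ar_mask.sum(0) > 0).nonzero()
    print(to_generate.shape)                        # torch.Size([0, 1])
    # to_generate.min()                             # RuntimeError: empty reduction
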
@@ -174,11 +177,11 @@ class QuizMachine:
 
     ######################################################################
 
-    def predict(self, input, struct, mask):
+    def predict(self, model, quizzes, struct, mask):
         ar_mask = self.problem.make_ar_mask(quizzes=quizzes, struct=struct, mask=mask)
         result = quizzes * (1 - ar_mask)
 
-        seq_logproba = torch.empty(fwd_quizzes, device=self.device)
+        seq_logproba = torch.empty(quizzes.size(0), device=self.device)
 
         masked_inplace_autoregression(
             model=model,
@@ -186,50 +189,47 @@ class QuizMachine:
             input=result,
             ar_mask=ar_mask,
             seq_logproba=seq_logproba,
-            deterministic_synthesis=deterministic_synthesis,
+            deterministic_synthesis=False,
             progress_bar_desc="accuracy",
             device=self.device,
         )
 
-        nb_correct = (result == quizzes).min(dim=1).long()
+        correct = (result == quizzes).min(dim=1).values
 
         return result, correct
 
     def produce_results(
-        self, n_epoch, model, input, result_dir, deterministic_synthesis
+        self,
+        n_epoch,
+        model,
+        input,
+        result_dir,
     ):
         input = input.to(self.device)
-        i = self.problem.indices_select(quizzes=input, struct=struct)
-
-        input_fwd = input[i]
-        test_result_fwd, test_correct_fwd = predict(
-            input_fwd, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
-        )
-
-        input_bck = self.problem.reconfigure(
-            predict(input[i == False], ("f_B", "f_A", "B", "A"), (0, 1, 1, 1))[0],
-            struct=("A", "f_A", "B", "f_B"),
-        )
-
-        l = input_bck.size(1) // 4
-        input_bck[:, 3 * l :] = input[i == False][:, :l]
+        result = input.new(input.size())
+        correct = torch.empty(input.size(0), device=input.device, dtype=torch.bool)
+
+        nb = 0
+        for struct, mask in [
+            (("A", "f_A", "B", "f_B"), (0, 0, 0, 1)),
+            (("f_B", "f_A", "B", "A"), (0, 1, 1, 1)),
+        ]:
+            i = self.problem.indices_select(quizzes=input, struct=struct)
+            nb += i.long().sum()
+            result[i], correct[i] = self.predict(
+                model=model, quizzes=input[i], struct=struct, mask=mask
+            )
 
-        test_result_bck, test_correct_bck = predict(
-            input_bck, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
-        )
+        assert nb == input.size(0)
 
-        main_test_accuracy = test_correct.sum() / test_correct.size(0)
+        main_test_accuracy = correct.sum() / correct.size(0)
 
         ##############################
 
-        test_result = torch.cat([test_result_fwd[:64], test_result_bck[:64]], dim=0)
-        test_correct = torch.cat([test_correct_fwd[:64], test_correct_bck[:64]], dim=0)
-
-        self.save_quiz_illustrations(
+        self.problem.save_quizzes_as_image(
             result_dir,
-            f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
-            quizzes=test_result,
-            # mistakes=test_correct,
+            f"culture_prediction_{n_epoch:04d}_{model.id:02d}.png",
+            quizzes=result[:128],
         )
 
         return main_test_accuracy
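
Note: produce_results now routes each test quiz to the prediction configuration matching its structure, writes the per-subset results back through the same boolean index, and asserts that the two subsets cover the whole batch. A stripped-down sketch of that select / predict / scatter-back pattern, with toy stand-ins for indices_select and predict:

    import torch

    x = torch.arange(10)
    result = x.new_zeros(x.size())
    correct = torch.empty(x.size(0), dtype=torch.bool)

    nb = 0
    for selector in (lambda t: t % 2 == 0, lambda t: t % 2 == 1):
        i = selector(x)                 # boolean mask, plays the role of indices_select
        nb += i.long().sum()
        result[i] = x[i] * 10           # stand-in for predict(...)[0]
        correct[i] = x[i] > 3           # stand-in for predict(...)[1]

    assert nb == x.size(0)              # every quiz handled by exactly one configuration
    accuracy = correct.sum() / correct.size(0)
    print(accuracy)                     # tensor(0.6000)
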
@@ -355,12 +355,19 @@ class QuizMachine:
 
         seq_logproba[...] = 0.0
 
+        c_quizzes = c_quizzes.to(self.device)
+        print(self.problem.get_structure(c_quizzes))
+        reversed_c_quizzes = self.problem.reconfigure(
+            c_quizzes, ("f_A", "A", "f_B", "B")
+        )
+
         for model in models_for_validation:
             # A, f(A), B | f(B)
-            c_quizzes = c_quizzes.to(self.device)
             result = c_quizzes.clone()
 
-            ar_mask = self.problem.make_ar_mask(result, shape="fwd_3_bck_3")
+            ar_mask = self.problem.make_ar_mask(
+                result, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)
+            )
 
             masked_inplace_autoregression(
                 model=model,
@@ -377,10 +384,11 @@ class QuizMachine:
             # -------------------------------
 
             # f(A), A, f(B) | B
-            c_quizzes = self.problem.flip(c_quizzes, pairwise_flip=True).to(self.device)
-            result = c_quizzes.clone()
+            result = reversed_c_quizzes.clone()
 
-            ar_mask = self.problem.make_ar_mask(result, shape="fwd_3_bck_3")
+            ar_mask = self.problem.make_ar_mask(
+                result, ("f_A", "A", "f_B", "B"), mask=(0, 0, 0, 1)
+            )
 
             masked_inplace_autoregression(
                 model=model,
@@ -392,7 +400,7 @@ class QuizMachine:
                 device=self.device,
             )
 
-            correct *= (c_quizzes == result).long().min(dim=-1).values
+            correct *= (reversed_c_quizzes == result).long().min(dim=-1).values
 
             # -------------------------------
 
@@ -409,11 +417,8 @@ class QuizMachine:
         temperature_hot=1.0,
         temperature_cold=1.0,
     ):
-        c_quizzes = torch.empty(
-            nb,
-            self.problem.seq_len,
-            device=self.device,
-            dtype=torch.int64,
+        c_quizzes = self.problem.create_empty_quizzes(nb, ("f_B", "f_A", "A", "B")).to(
+            self.device
         )
 
         seq_logproba = torch.zeros(nb, device=self.device)
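
Note: generation now starts from an empty quiz with structure ("f_B", "f_A", "A", "B") built by create_empty_quizzes, instead of a raw tensor filled with token_backward. As far as the visible hunks go, the stages that follow can be summarised as (struct, mask) pairs fed to make_ar_mask:

    # The three generation stages in the hunks below; 1 = positions to generate.
    stages = [
        (("f_B", "f_A", "A", "B"), (1, 0, 0, 0)),   # f_B free-form, noisy logits (lt_noisy)
        (("f_B", "f_A", "A", "B"), (0, 1, 1, 1)),   # f_A, A, B conditioned on f_B (lt_clean)
        (("A", "f_A", "B", "f_B"), (0, 0, 0, 1)),   # after reconfigure: re-derive f_B (lt_clean)
    ]
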
@@ -426,13 +431,13 @@ class QuizMachine:
         # )
         # lt_clean = None
 
-        c_quizzes[...] = self.problem.token_backward
-
         masked_inplace_autoregression(
             model=model_for_generation,
             batch_size=self.batch_size,
             input=c_quizzes,
-            ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"),
+            ar_mask=self.problem.make_ar_mask(
+                c_quizzes, ("f_B", "f_A", "A", "B"), (1, 0, 0, 0)
+            ),
             seq_logproba=seq_logproba,
             logit_transformer=lt_noisy,
             deterministic_synthesis=False,
@@ -443,20 +448,24 @@ class QuizMachine:
             model=model_for_generation,
             batch_size=self.batch_size,
             input=c_quizzes,
-            ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
+            ar_mask=self.problem.make_ar_mask(
+                c_quizzes, ("f_B", "f_A", "A", "B"), (0, 1, 1, 1)
+            ),
             seq_logproba=seq_logproba,
             logit_transformer=lt_clean,
             deterministic_synthesis=False,
             device=self.device,
         )
 
-        c_quizzes = self.problem.p_a_flip(c_quizzes)
+        c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B"))
 
         masked_inplace_autoregression(
             model=model_for_generation,
             batch_size=self.batch_size,
             input=c_quizzes,
-            ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
+            ar_mask=self.problem.make_ar_mask(
+                c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
+            ),
             seq_logproba=seq_logproba,
             logit_transformer=lt_clean,
             deterministic_synthesis=False,