Update.
author François Fleuret <francois@fleuret.org>
Wed, 24 Jul 2024 20:51:17 +0000 (22:51 +0200)
committer François Fleuret <francois@fleuret.org>
Wed, 24 Jul 2024 20:51:17 +0000 (22:51 +0200)
grids.py
main.py
quiz_machine.py

index ee3a1e6..25bbc80 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -138,25 +138,30 @@ class Grids(problem.Problem):
         return struct
 
     def reconfigure(self, quizzes, struct=("A", "f_A", "B", "f_B")):
+        if torch.is_tensor(quizzes):
+            return self.reconfigure([quizzes])[0]
+
         S = self.height * self.width
-        result = quizzes.new(quizzes.size())
+        result = [x.new(x.size()) for x in quizzes]
 
-        struct_from = self.get_structure(quizzes[:1])
-        i = self.indices_select(quizzes, struct_from)
+        struct_from = self.get_structure(quizzes[0][:1])
+        i = self.indices_select(quizzes[0], struct_from)
 
         sf = dict((l, n) for n, l in enumerate(struct_from))
 
-        q = quizzes.reshape(-1, 4, S + 1)[i]
-
-        result[i, 0 * (S + 1) : 1 * (S + 1)] = q[:, sf[struct[0]], :]
-        result[i, 1 * (S + 1) : 2 * (S + 1)] = q[:, sf[struct[1]], :]
-        result[i, 2 * (S + 1) : 3 * (S + 1)] = q[:, sf[struct[2]], :]
-        result[i, 3 * (S + 1) : 4 * (S + 1)] = q[:, sf[struct[3]], :]
+        for q in range(4):
+            k = sf[struct[q]]
+            for x, y in zip(quizzes, result):
+                l = x.size(1) // 4
+                y[i, q * l : (q + 1) * l] = x[i, k * l : (k + 1) * l]
 
         j = i == False
 
         if j.any():
-            result[j] = self.reconfigure(quizzes[j], struct=struct)
+            for z, y in zip(
+                self.reconfigure([x[j] for x in quizzes], struct=struct), result
+            ):
+                y[j] = z
 
         return result
 
@@ -303,6 +308,7 @@ class Grids(problem.Problem):
         margin=8,
     ):
         quizzes = quizzes.to("cpu")
+        self.check_structure(quizzes, ("A", "f_A", "B", "f_B"))
 
         S = self.height * self.width
 
@@ -339,8 +345,9 @@ class Grids(problem.Problem):
                 colors = (
                     predicted_parts[:, :, None]
                     * (
-                        correct_parts[:, :, None] * green[None, None, :]
-                        + (1 - correct_parts[:, :, None]) * red[None, None, :]
+                        (correct_parts[:, :, None] == 1).long() * green[None, None, :]
+                        + (correct_parts[:, :, None] == 0).long() * gray[None, None, :]
+                        + (correct_parts[:, :, None] == -1).long() * red[None, None, :]
                     )
                     + (1 - predicted_parts[:, :, None]) * white[None, None, :]
                 )
@@ -1321,21 +1328,18 @@ class Grids(problem.Problem):
                     X[i1:i2, j1:j2] = c[n]
                     f_X[i1:i2, j1:j2] = c[n]
 
-                while True:
-                    i1, i2 = torch.randint(self.height, (2,))
-                    j1, j2 = torch.randint(self.width, (2,))
-                    if (
-                        abs(i1 - i2) + abs(j1 - j2) > 2
-                        and X[i1, j1] == 0
-                        and X[i2, j2] == 0
-                    ):
-                        break
-
-                d2 = self.compdist(X, i2, j2)
-                d = self.compdist(X, i1, j1)
+                i1, i2 = torch.randint(self.height, (2,))
+                j1, j2 = torch.randint(self.width, (2,))
+                if (
+                    abs(i1 - i2) + abs(j1 - j2) > 2
+                    and X[i1, j1] == 0
+                    and X[i2, j2] == 0
+                ):
+                    d2 = self.compdist(X, i2, j2)
+                    d = self.compdist(X, i1, j1)
 
-                if d2[i1, j1] < 2 * self.width:
-                    break
+                    if d2[i1, j1] < 2 * self.width:
+                        break
 
             m = ((d + d2) == d[i2, j2]).long()
             f_X[...] = m * c[-1] + (1 - m) * f_X
@@ -1491,32 +1495,22 @@ if __name__ == "__main__":
     nb, nrow = 128, 4
     # nb, nrow = 8, 2
 
-    for t in grids.all_tasks:
-        # for t in [grids.task_replace_color]:
-        # for t in [grids.task_symbols]:
+    # for t in grids.all_tasks:
+    for t in [grids.task_path]:
         print(t.__name__)
         quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
-        # predicted_parts = quizzes.new_zeros(quizzes.size(0), 4)
-        # predicted_parts[:, 3] = torch.randint(
-        # 2, (quizzes.size(0),), device=quizzes.device
-        # )
-        # predicted_parts[:, :3] = 1 - predicted_parts[:, 3:]
-        # correct_parts = torch.randint(2, (quizzes.size(0), 4), device=quizzes.device)
-        # correct_parts[:, 1:2] = correct_parts[:, :1]
         grids.save_quizzes_as_image(
             "/tmp",
             t.__name__ + ".png",
             quizzes,
-            # predicted_parts=predicted_parts,
-            # correct_parts=correct_parts,
         )
 
     # exit(0)
 
     nb = 1000
 
-    for t in grids.all_tasks:
-        # for t in [grids.task_compute]:
+    # for t in grids.all_tasks:
+    for t in [grids.task_path]:
         start_time = time.perf_counter()
         w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
         delay = time.perf_counter() - start_time
diff --git a/main.py b/main.py
index deba848..fa33b4e 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -89,7 +89,7 @@ parser.add_argument("--nb_gpts", type=int, default=5)
 
 parser.add_argument("--max_fail_to_validate", type=int, default=1)
 
-parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975)
+parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.9)
 
 parser.add_argument("--proba_understands", type=float, default=0.9)
 
@@ -99,7 +99,7 @@ parser.add_argument("--temperature_hot", type=float, default=1.5)
 
 parser.add_argument("--temperature_cold", type=float, default=0.75)
 
-parser.add_argument("--nb_rounds", type=int, default=3)
+parser.add_argument("--nb_rounds", type=int, default=1)
 
 parser.add_argument("--c_quiz_validation_mode", type=str, default="predict")
 
@@ -549,7 +549,8 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
                 v = " ".join([str(n.item()) for n in r])
                 f.write(f"{n}: {v}\n")
 
-        quiz_machine.save_quizzes_as_image(args.result_dir, prefix, vq)
+        vq = quiz_machine.problem.reconfigure(vq, ("A", "f_A", "B", "f_B"))
+        quiz_machine.problem.save_quizzes_as_image(args.result_dir, prefix, vq)
 
 
 ######################################################################
@@ -724,6 +725,9 @@ for n_epoch in range(current_epoch, args.nb_epochs):
             temperature_cold=args.temperature_cold,
         )
 
+        c_quizzes = quiz_machine.problem.reconfigure(
+            c_quizzes, ("A", "f_A", "B", "f_B")
+        )
         quiz_machine.problem.save_quizzes_as_image(
             args.result_dir,
             f"non_validated_{n_epoch:04d}_{model.id:02d}.png",
diff --git a/quiz_machine.py b/quiz_machine.py
index 749ae8b..8f14fa0 100755 (executable)
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -200,7 +200,7 @@ class QuizMachine:
             device=self.device,
         )
 
-        correct = (result == quizzes).min(dim=1).values
+        correct = (result == quizzes).min(dim=1).values.long()
 
         return result, correct
 
@@ -213,7 +213,7 @@ class QuizMachine:
     ):
         input = input.to(self.device)
         result = input.new(input.size())
-        correct = torch.empty(input.size(0), device=input.device, dtype=torch.bool)
+        correct = input.new(input.size(0))
         predicted_parts = input.new(input.size(0), 4)
         nb = 0
         for struct, mask in [
@@ -226,19 +226,40 @@ class QuizMachine:
             result[i], correct[i] = self.predict(
                 model=model, quizzes=input[i], struct=struct, mask=mask
             )
+
             predicted_parts[i] = torch.tensor(mask, device=self.device)[None, :]
+            correct[i] = (2 * correct[i] - 1) * (
+                predicted_parts[i].sum(dim=-1) == 1
+            ).long()
 
         assert nb == input.size(0)
 
-        main_test_accuracy = correct.sum() / correct.size(0)
+        nb_correct = (correct == 1).long().sum()
+        nb_total = (correct != 0).long().sum()
+        self.logger(
+            f"test_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total}"
+        )
+
+        main_test_accuracy = nb_correct / nb_total
 
         ##############################
 
+        correct_parts = predicted_parts * correct[:, None]
+
+        result = result[:128]
+        predicted_parts = predicted_parts[:128]
+        correct_parts = correct_parts[:128]
+
+        result, predicted_parts, correct_parts = self.problem.reconfigure(
+            [result, predicted_parts, correct_parts], ("A", "f_A", "B", "f_B")
+        )
+
         self.problem.save_quizzes_as_image(
             result_dir,
             f"culture_prediction_{n_epoch:04d}_{model.id:02d}.png",
-            quizzes=result[:128],
+            quizzes=result,
             predicted_parts=predicted_parts,
+            correct_parts=correct_parts,
         )
 
         return main_test_accuracy
@@ -437,11 +458,6 @@ class QuizMachine:
         lt_noisy = lambda s, logits: logits / temperature_hot
         lt_clean = lambda s, logits: logits / temperature_cold
 
-        # lt_noisy = lambda s, logits: logits / (
-        # 1 + 4 * (torch.rand(logits.size(), device=logits.device) < 1e-2).long()
-        # )
-        # lt_clean = None
-
         masked_inplace_autoregression(
             model=model_for_generation,
             batch_size=self.batch_size,