Update.
author    François Fleuret <francois@fleuret.org>
          Fri, 26 Jul 2024 07:40:04 +0000 (09:40 +0200)
committer François Fleuret <francois@fleuret.org>
          Fri, 26 Jul 2024 07:40:04 +0000 (09:40 +0200)
grids.py
main.py
quiz_machine.py

diff --git a/grids.py b/grids.py
index 67a5c97..3453d4a 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -336,12 +336,13 @@ class Grids(problem.Problem):
         predicted_parts=None,
         correct_parts=None,
         comments=None,
-        comment_height=64,
+        comment_height=48,
         nrow=4,
         margin=8,
     ):
         quizzes = quizzes.to("cpu")
-        self.check_structure(quizzes, ("A", "f_A", "B", "f_B"))
+
+        if not self.check_structure(quizzes, ("A", "f_A", "B", "f_B")):
+            print(f"**WARNING** {filename} is not in A/f_A/B/f_B order")
 
         S = self.height * self.width
 
@@ -1490,6 +1492,34 @@ class Grids(problem.Problem):
             X[i2:ii, jj1:jj2] = c[4]
             f_X[i2:ii, jj1:jj2] = c[4]
 
+    def task_science_dot(self, A, f_A, B, f_B):
+        nb_rec = 3
+        c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            while True:
+                X[...] = 0
+                f_X[...] = 0
+                r = self.rec_coo(nb_rec, prevent_overlap=True)
+                i, j = (
+                    torch.randint(self.height, (1,)).item(),
+                    torch.randint(self.width, (1,)).item(),
+                )
+                q = 0
+                for n in range(nb_rec):
+                    i1, j1, i2, j2 = r[n]
+                    X[i1:i2, j1:j2] = c[n]
+                    f_X[i1:i2, j1:j2] = c[n]
+                    if i >= i1 and i < i2:
+                        q += 1
+                        f_X[i, j1:j2] = c[-1]
+                    if j >= j1 and j < j2:
+                        q += 1
+                        f_X[i1:i2, j] = c[-1]
+                X[i, j] = c[-1]
+                f_X[i, j] = c[-1]
+                if q >= 2:
+                    break
+
     # end_tasks
 
     ######################################################################
@@ -1529,13 +1559,13 @@ class Grids(problem.Problem):
 
         return quizzes
 
-    def save_some_examples(self, result_dir):
+    def save_some_examples(self, result_dir, prefix=""):
         nb, nrow = 128, 4
         for t in self.all_tasks:
             print(t.__name__)
             quizzes = self.generate_w_quizzes_(nb, tasks=[t])
             self.save_quizzes_as_image(
-                result_dir, t.__name__ + ".png", quizzes, nrow=nrow
+                result_dir, prefix + t.__name__ + ".png", quizzes, nrow=nrow
             )
 
 
@@ -1588,7 +1618,7 @@ if __name__ == "__main__":
 
     # for t in grids.all_tasks:
 
-    for t in [grids.task_science_implicit]:
+    for t in [grids.task_science_dot]:
         print(t.__name__)
         quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
         grids.save_quizzes_as_image(
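
The grids.py changes above add task_science_dot, which draws three non-overlapping colored rectangles plus a dot of a distinct color and, in the transformed grid, highlights inside each rectangle the row and the column passing through the dot, resampling until at least two such crossings occur. They also give save_some_examples an optional filename prefix. A minimal usage sketch, not part of the patch, assuming Grids() builds with its default constructor arguments and that /tmp is writable:

    # Sketch only: render a sheet of the new task and the prefixed example sheets.
    # The no-argument Grids() constructor and the /tmp output directory are assumptions.
    import grids

    g = grids.Grids()
    quizzes = g.generate_w_quizzes_(128, tasks=[g.task_science_dot])
    g.save_quizzes_as_image("/tmp", "task_science_dot.png", quizzes, nrow=4)
    g.save_some_examples("/tmp", prefix="science_")  # writes one science_<task>.png per task
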
diff --git a/main.py b/main.py
index 4d618cc..4310e6b 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -95,7 +95,7 @@ parser.add_argument("--proba_understands", type=float, default=0.9)
 
 parser.add_argument("--proba_not_understands", type=float, default=0.5)
 
-parser.add_argument("--temperature_hot", type=float, default=1.5)
+parser.add_argument("--temperature_hot", type=float, default=2)
 
 parser.add_argument("--temperature_cold", type=float, default=0.75)
 
@@ -323,6 +323,9 @@ elif args.problem == "grids":
             tasks=args.grids_science_tasks,
         )
         science_w_quizzes = science_problem.generate_w_quizzes(args.nb_test_samples)
+        if not args.resume:
+            problem.save_some_examples(args.result_dir, "science_")
+
 
 else:
     raise ValueError
@@ -447,6 +450,71 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 ######################################################################
 
 
+def save_additional_results(models, science_w_quizzes):
+    for model in models:
+        c_quizzes = quiz_machine.generate_c_quizzes(
+            128,
+            model_for_generation=model,
+            temperature_hot=args.temperature_hot,
+            temperature_cold=args.temperature_cold,
+        )
+
+        c_quizzes = quiz_machine.problem.reconfigure(
+            c_quizzes, ("A", "f_A", "B", "f_B")
+        )
+
+        quiz_machine.problem.save_quizzes_as_image(
+            args.result_dir,
+            f"non_validated_{n_epoch:04d}_{model.id:02d}.png",
+            c_quizzes,
+        )
+
+    ######################################################################
+
+    if science_w_quizzes is not None:
+        for model in models:
+            struct = ("A", "f_A", "B", "f_B")
+            mask = (0, 0, 0, 1)
+            result, correct = quiz_machine.predict(
+                model=model,
+                quizzes=science_w_quizzes.to(main_device),
+                struct=struct,
+                mask=mask,
+            )
+
+            predicted_parts = torch.tensor(mask, device=correct.device)[None, :]
+            correct = (2 * correct - 1) * (predicted_parts.sum(dim=-1) == 1).long()
+
+            nb_correct = (correct == 1).long().sum()
+            nb_total = (correct != 0).long().sum()
+
+            log_string(
+                f"science_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total}"
+            )
+
+            i = correct == 1
+            j = correct != 1
+
+            result = torch.cat([result[i], result[j]], dim=0)
+            correct = torch.cat([correct[i], correct[j]], dim=0)
+            correct_parts = predicted_parts * correct[:, None]
+
+            result = result[:128]
+            predicted_parts = predicted_parts[:128]
+            correct_parts = correct_parts[:128]
+
+            quiz_machine.problem.save_quizzes_as_image(
+                args.result_dir,
+                f"culture_science_{n_epoch:04d}_{model.id:02d}.png",
+                quizzes=result,
+                predicted_parts=predicted_parts,
+                correct_parts=correct_parts,
+            )
+
+
+######################################################################
+
+
 def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
     nb_to_validate = nb_for_train + nb_for_test
     nb_to_generate_per_iteration = max(args.physical_batch_size, nb_to_validate)
@@ -562,7 +630,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
 
     if vq.size(0) > 0:
         number_correct_responses = 0
-        for r in range(args.nb_rounds):
+        for r in range(10):
             number_correct_responses += quiz_machine.models_successes(models, vq)
 
         comments = []
@@ -740,39 +808,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         )
         log_string(f"wrote {filename}")
 
-    for model in weakest_models:
-        c_quizzes = quiz_machine.generate_c_quizzes(
-            128,
-            model_for_generation=model,
-            temperature_hot=args.temperature_hot,
-            temperature_cold=args.temperature_cold,
-        )
-
-        c_quizzes = quiz_machine.problem.reconfigure(
-            c_quizzes, ("A", "f_A", "B", "f_B")
-        )
-
-        quiz_machine.problem.save_quizzes_as_image(
-            args.result_dir,
-            f"non_validated_{n_epoch:04d}_{model.id:02d}.png",
-            c_quizzes,
-        )
-
-    ######################################################################
-
-    if science_w_quizzes is not None:
-        result, correct = quiz_machine.predict(
-            model=model,
-            quizzes=science_w_quizzes.to(main_device),
-            struct=("A", "f_A", "B", "f_B"),
-            mask=(0, 0, 0, 1),
-        )
-
-        nb_correct = (correct == 1).long().sum()
-        nb_total = (correct != 0).long().sum()
-        log_string(
-            f"science_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total}"
-        )
+    save_additional_results(models, science_w_quizzes)
 
     ######################################################################
 
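
In main.py, the per-epoch reporting that used to sit inside the training loop moves into save_additional_results(models, science_w_quizzes): for every model it regenerates 128 raw c-quizzes and, when science quizzes are available, measures prediction accuracy with only the f_B part masked. A toy sketch of the accuracy bookkeeping, not part of the patch, assuming quiz_machine.predict returns correct as 0/1 per quiz:

    # Toy sketch of the mask/accuracy arithmetic used in save_additional_results.
    # The 0/1 values of "correct" and the three-quiz example are assumptions.
    import torch

    mask = (0, 0, 0, 1)                            # only the f_B part is predicted
    correct = torch.tensor([1, 1, 0])              # right, right, wrong
    predicted_parts = torch.tensor(mask)[None, :]  # shape (1, 4)
    correct = (2 * correct - 1) * (predicted_parts.sum(dim=-1) == 1).long()
    nb_correct = (correct == 1).long().sum()       # tensor(2)
    nb_total = (correct != 0).long().sum()         # tensor(3)
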
diff --git a/quiz_machine.py b/quiz_machine.py
index 8e40921..a9319c7 100755 (executable)
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -132,7 +132,10 @@ class QuizMachine:
         self.train_struct = [
             ("A", "f_A", "B", "f_B"),  # The standard order
             ("f_A", "A", "f_B", "B"),  # The reverse order for validation
+            ("B", "f_B", "A", "f_A"),
+            ("f_B", "B", "f_A", "A"),
             ("f_B", "f_A", "A", "B"),  # The synthesis order
+            ("f_B", "f_A", "A", "B"),  # twice!
         ]
 
         self.LOCK_C_QUIZZES = threading.Lock()
@@ -224,6 +227,8 @@ class QuizMachine:
         for struct, mask in [
             (("A", "f_A", "B", "f_B"), (0, 0, 0, 1)),
             (("f_A", "A", "f_B", "B"), (0, 0, 0, 1)),
+            (("B", "f_B", "A", "f_A"), (0, 0, 0, 1)),
+            (("f_B", "B", "f_A", "A"), (0, 0, 0, 1)),
             (("f_B", "f_A", "A", "B"), (0, 1, 1, 1)),
         ]:
             i = self.problem.indices_select(quizzes=input, struct=struct)
@@ -490,3 +495,33 @@ class QuizMachine:
         return c_quizzes.to("cpu")
 
     ######################################################################
+
+    def generate_c_quizzes_simple(
+        self,
+        nb,
+        model_for_generation,
+        temperature_hot=1.0,
+        temperature_cold=1.0,
+    ):
+        c_quizzes = self.problem.create_empty_quizzes(nb, ("A", "f_A", "B", "f_B"))
+        c_quizzes = c_quizzes.to(self.device)
+
+        seq_logproba = torch.zeros(nb, device=self.device)
+
+        lt_noisy = lambda s, logits: logits / temperature_hot
+
+        masked_inplace_autoregression(
+            model=model_for_generation,
+            batch_size=self.batch_size,
+            input=c_quizzes,
+            ar_mask=self.make_ar_mask(
+                c_quizzes, ("A", "f_A", "B", "f_B"), (1, 1, 1, 1)
+            ),
+            seq_logproba=seq_logproba,
+            logit_transformer=lt_noisy,
+            device=self.device,
+        )
+
+        return c_quizzes.to("cpu")
+
+    ######################################################################
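
quiz_machine.py also gains generate_c_quizzes_simple, a single-pass generator that fills all four parts of empty A/f_A/B/f_B quizzes at temperature_hot (temperature_cold is accepted but unused there); nothing in main.py calls it yet in this patch. A usage sketch, not part of the patch, as it could be invoked from inside the per-model loop of main.py where quiz_machine, model, args and n_epoch are in scope; the output file name is an assumption:

    # Sketch only: mirrors the non-validated c-quiz dump in save_additional_results.
    c_quizzes = quiz_machine.generate_c_quizzes_simple(
        128,
        model_for_generation=model,
        temperature_hot=args.temperature_hot,  # default raised to 2 in this patch
    )

    quiz_machine.problem.save_quizzes_as_image(
        args.result_dir,
        f"c_quizzes_simple_{n_epoch:04d}_{model.id:02d}.png",
        c_quizzes,
    )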