Add Grids.text2quiz() and a --quizzes option to run the models on quizzes read from a text file; replace quiz_validation_paris/berne with a single quiz_validation().
author     François Fleuret <francois@fleuret.org>
           Mon, 9 Sep 2024 18:58:04 +0000 (20:58 +0200)
committer  François Fleuret <francois@fleuret.org>
           Mon, 9 Sep 2024 18:58:04 +0000 (20:58 +0200)
grids.py
main.py

diff --git a/grids.py b/grids.py
index 73e722e..054ba35 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -5,7 +5,7 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
-import math, sys, tqdm, os, warnings, cairo
+import math, sys, tqdm, os, warnings, cairo, re
 
 import torch, torchvision
 
@@ -234,6 +234,51 @@ class Grids(problem.Problem):
 
         return ar_mask
 
+    def text2quiz(self, t):
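+        # Parse a plain-text description of quizzes: lines starting with "#"
+        # are comments, blank lines separate quizzes, and each quiz is 10 rows
+        # of 4 space-separated 10-character grids (A, f_A, B, f_B), one
+        # character per cell color.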
+        chr2col = [
+            (".", "white"),
+            ("r", "red"),
+            ("g", "green"),
+            ("b", "blue"),
+            ("y", "yellow"),
+            ("c", "cyan"),
+            ("v", "violet"),
+            ("l", "lightgreen"),
+            ("o", "brown"),
+            ("u", "lightblue"),  # distinct key assumed, "l" already maps to lightgreen
+            ("a", "gray"),
+        ]
+
+        col2tok = {c[0]: n for n, c in enumerate(self.named_colors)}
+        chr2tok = {c: col2tok[col] for c, col in chr2col}
+
+        t = re.sub(r"#.*\n", "", t).strip()
+        l = t.replace("\n\n", ";").split(";")
+
+        result = []
+
+        for t in l:
+            t = "".join(t.replace("\n", " ").strip().split(" "))
+            t = torch.tensor([chr2tok[c] for c in t])
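+            # 10 text rows x 4 grids x 10 characters -> 4 grids of 100 cells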
+            t = t.reshape(10, 4, 10).permute(1, 0, 2).flatten(1)
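+            # prepend each grid's type token (A, f_A, B, f_B) as a first cell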
+            t = torch.cat(
+                [
+                    torch.tensor(
+                        [
+                            [self.token_A],
+                            [self.token_f_A],
+                            [self.token_B],
+                            [self.token_f_B],
+                        ]
+                    ),
+                    t,
+                ],
+                dim=1,
+            )
+            result.append(t.flatten()[None, :])
+
+        return torch.cat(result, dim=0)
+
     def indices_select(self, quizzes, quad_order=("A", "f_A", "B", "f_B")):
         S = self.height * self.width
         q = quizzes.reshape(quizzes.size(0), 4, S + 1)
@@ -1798,41 +1843,65 @@ if __name__ == "__main__":
 
     grids = Grids()
 
-    # nb = 5
-    # quizzes = grids.generate_w_quizzes_(nb, tasks=[grids.task_fill])
-    # print(quizzes)
-    # print(grids.get_order(quizzes))
-    # quizzes = grids.reconfigure(quizzes, struct=("A", "B", "f_A", "f_B"))
-    # print("DEBUG2", quizzes)
-    # print(grids.get_order(quizzes))
-    # print(quizzes)
-
-    # i = torch.rand(quizzes.size(0)) < 0.5
-
-    # quizzes[i] = grids.reconfigure(quizzes[i], struct=("f_B", "f_A", "B", "A"))
-
-    # j = grids.indices_select(quizzes, struct=("f_B", "f_A", "B", "A"))
-
-    # print(
-    # i.equal(j),
-    # grids.get_order(quizzes[j]),
-    # grids.get_order(quizzes[j == False]),
-    # )
-
-    #   exit(0)
-
-    # nb = 1000
-    # grids = problem.MultiThreadProblem(
-    # grids, max_nb_cached_chunks=50, chunk_size=100, nb_threads=1
-    # )
-    #    time.sleep(10)
-    # start_time = time.perf_counter()
-    # prompts, answers = grids.generate_w_quizzes(nb)
-    # delay = time.perf_counter() - start_time
-    # print(f"{prompts.size(0)/delay:02f} seq/s")
-    # exit(0)
-
-    # if True:
+    q = grids.text2quiz(
+        """
+
+# the original
+
+vvvvaaaaa. rrrraaaaa. .......... ..........
+vvvvaaaaa. rrrraaaaa. ...aaa.... ...aaa....
+vvvvaaaaa. rrrraaaaa. ...aaa.... ...aaa....
+vvvvaaaaa. rrrraaaaa. ...aaa.... ...aaa....
+....aaaaa. ....aaaaa. .vvvvv.... .rrrrr....
+.......... .......... .vvvvvvvvv .rrrrroooo
+.......... .......... .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+
+vvvvaaaaa. rrrraaaaa. .......... ..........
+vvvvaaaaa. rrrraaaaa. .......... ..........
+vvvvaaaaa. rrrraaaaa. .......aaa .......aaa
+vvvvaaaaa. rrrraaaaa. .......aaa .......aaa
+....aaaaa. ....aaaaa. .vvvvv.aaa .rrrrr.aaa
+.......... .......... .vvvvvvvvv .rrrrroooo
+.......... .......... .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+
+#
+# so what
+#
+
+vvvv...... rrrr...... .......... ..........
+vvvv...... rrrr...... .......... ..........
+vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa
+vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa
+.....aaaaa .....aaaaa .vvvvv.aaa .rrrrr.aaa
+.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo
+.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+
+vvvv...... rrrr...... .......... ..........
+vvvv...... rrrr...... .......... ..........
+vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa
+vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa
+.....aaaaa .....aaaaa .vvvvv.aaa .rrrrr.aaa
+.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo
+.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+"""
+    )
+
+    grids.save_quizzes_as_image("/tmp", "test.png", q, nrow=1)
+
+    exit(0)
+
     nb, nrow = 128, 4
     # nb, nrow = 8, 2
 
diff --git a/main.py b/main.py
index d914113..97d37ce 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -129,6 +129,8 @@ parser.add_argument("--dirty_debug", action="store_true", default=False)
 
 parser.add_argument("--test", type=str, default=None)
 
+parser.add_argument("--quizzes", type=str, default=None)
+
 ######################################################################
 
 grids_tasks = ", ".join(
@@ -1060,11 +1062,9 @@ def save_badness_statistics(
 ######################################################################
 
 
-def quiz_validation_paris(models, c_quizzes, local_device):
+def quiz_validation(models, c_quizzes, local_device):
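+    # A quiz is validated if at least nb_have_to_be_correct models regenerate
+    # every quadrant without error, and at least nb_have_to_be_wrong models
+    # make nb_mistakes_to_be_wrong mistakes or more on some quadrant.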
     nb_have_to_be_correct = 3
-    nb_have_to_be_wrong = 1
-
-    nb_runs = 3
+    nb_have_to_be_wrong = 3
     nb_mistakes_to_be_wrong = 5
 
     record_wrong = []
@@ -1073,47 +1073,22 @@ def quiz_validation_paris(models, c_quizzes, local_device):
     for i, model in enumerate(models):
         assert i == model.id  # a bit of paranoia
         model = copy.deepcopy(model).to(local_device).eval()
-        correct, wrong = True, False
-        for _ in range(nb_runs):
-            n = model_ae_argmax_nb_mistakes(model, c_quizzes).long()
-            correct = correct & (n == 0)
-            wrong = wrong | (n >= nb_mistakes_to_be_wrong)
-        record_wrong.append(wrong[:, None])
-        nb_correct += correct.long()
-        nb_wrong += wrong.long()
-
-    # print("nb_correct", nb_correct)
-
-    # print("nb_wrong", nb_wrong)
-
-    to_keep = (nb_correct >= nb_have_to_be_correct) & (nb_wrong >= nb_have_to_be_wrong)
-
-    wrong = torch.cat(record_wrong, dim=1)
 
-    return to_keep, wrong
-
-
-def quiz_validation_berne(models, c_quizzes, local_device):
-    nb_have_to_be_correct = 3
-    nb_have_to_be_wrong = 1
-    nb_runs = 3
-
-    record_wrong = []
-    nb_correct, nb_wrong = 0, 0
+        correct, wrong = True, False
 
-    for i, model in enumerate(models):
-        assert i == model.id  # a bit of paranoia
-        model = copy.deepcopy(model).to(local_device).eval()
-        log_probas = 0
-        for _ in range(nb_runs):
-            log_probas += model_ae_proba_solutions(
-                model, c_quizzes, log_probas=True, reduce=False
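+        # Mask each quadrant in turn, let the model regenerate it from the
+        # other three, and count how many cells differ from the original.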
+        for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
+            mask_generate = quiz_machine.make_quiz_mask(
+                quizzes=c_quizzes, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
             )
-        probas = log_probas.exp()
-        correct = (probas <= 0.75).long().sum(dim=1) == 0
-        wrong = ((probas <= 0.125).long().sum(dim=1) >= 5) & (
-            log_probas.sum(dim=1).div(nb_runs).exp() <= 0.5
-        )
+            result = ae_generate(
+                model,
+                (1 - mask_generate) * c_quizzes,
+                mask_generate,
+            )
+            nb_mistakes = (result != c_quizzes).long().sum(dim=1)
+            correct = correct & (nb_mistakes == 0)
+            wrong = wrong | (nb_mistakes >= nb_mistakes_to_be_wrong)
+
         record_wrong.append(wrong[:, None])
         nb_correct += correct.long()
         nb_wrong += wrong.long()
@@ -1125,6 +1100,9 @@ def quiz_validation_berne(models, c_quizzes, local_device):
     return to_keep, wrong
 
 
+######################################################################
+
+
 def generate_ae_c_quizzes(models, nb, local_device=main_device):
     # To be thread-safe we must make copies
 
@@ -1137,10 +1115,6 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device):
         nb=args.inference_batch_size, quad_order=quad_order
     ).to(local_device)
 
-    mask_generate = quiz_machine.make_quiz_mask(
-        quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1)
-    )
-
     wanted_nb = nb
     nb_to_save = 256
     nb_c_quizzes_per_model = torch.zeros(len(models), device=local_device)
@@ -1155,15 +1129,31 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device):
             model = copy_for_inference(models[torch.randint(len(models), (1,)).item()])
             generator_id = model.id
 
+            mask_generate = quiz_machine.make_quiz_mask(
+                quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1)
+            )
+
             c_quizzes = ae_generate(model, template, mask_generate)
 
+            #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
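+            # Refine the generated quiz: regenerate f_A, then f_B, from the
+            # other three quadrants with the same model.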
+            for quad in [(0, 1, 0, 0), (0, 0, 0, 1)]:
+                mask_generate = quiz_machine.make_quiz_mask(
+                    quizzes=c_quizzes,
+                    quad_order=("A", "f_A", "B", "f_B"),
+                    quad_mask=quad,
+                )
+                c_quizzes = ae_generate(
+                    model,
+                    (1 - mask_generate) * c_quizzes,
+                    mask_generate,
+                )
+            #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
             to_keep = quiz_machine.problem.trivial(c_quizzes) == False
             c_quizzes = c_quizzes[to_keep]
 
             if c_quizzes.size(0) > 0:
-                to_keep, record_wrong = quiz_validation_berne(
-                    models, c_quizzes, local_device
-                )
+                to_keep, record_wrong = quiz_validation(models, c_quizzes, local_device)
                 q = c_quizzes[to_keep]
 
                 if q.size(0) > 0:
@@ -1199,51 +1189,36 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device):
         c_quizzes = torch.cat(record_c_quizzes, dim=0)
         agreements = torch.cat(record_agreements, dim=0)
 
-        subset_c_quizzes = c_quizzes[:nb_to_save]
-
-        #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-        # for model in models:
-        # for r in range(3):
-        # filename = f"culture_c_quiz_{n_epoch:04d}_prediction_{model.id}_{r}.png"
-        # p = model_ae_argmax_predictions(copy_for_inference(model), subset_c_quizzes)
-        # quiz_machine.problem.save_quizzes_as_image(
-        # args.result_dir,
-        # filename,
-        # quizzes=p,
-        # delta=True,
-        # nrow=8,
-        # )
-        # log_string(f"wrote {filename}")
-        #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+    return c_quizzes, agreements
 
-        filename = f"culture_c_quiz_{n_epoch:04d}.png"
 
-        l = [
-            model_ae_proba_solutions(copy_for_inference(model), subset_c_quizzes)
-            for model in models
-        ]
-        probas = torch.cat([x[:, None] for x in l], dim=1)
-        comments = []
-
-        for l in probas:
-            comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
-
-        quiz_machine.problem.save_quizzes_as_image(
-            args.result_dir,
-            filename,
-            quizzes=subset_c_quizzes,
-            comments=comments,
-            delta=True,
-            nrow=8,
-        )
+def thread_generate_ae_c_quizzes(models, nb, record, local_device=main_device):
+    record.append(generate_ae_c_quizzes(models, nb, local_device))
 
-        log_string(f"wrote {filename}")
 
-    return c_quizzes, agreements
+######################################################################
 
 
-def thread_generate_ae_c_quizzes(models, nb, record, local_device=main_device):
-    record.append(generate_ae_c_quizzes(models, nb, local_device))
+def save_c_quizzes_with_scores(models, c_quizzes, filename):
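+    # Compute, for every model, its estimated probability of each quiz's
+    # solution, and save the quizzes as an image annotated with these scores.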
+    l = [model_ae_proba_solutions(model, c_quizzes) for model in models]
+
+    probas = torch.cat([x[:, None] for x in l], dim=1)
+
+    comments = []
+
+    for l in probas:
+        comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
+
+    quiz_machine.problem.save_quizzes_as_image(
+        args.result_dir,
+        filename,
+        quizzes=c_quizzes,
+        comments=comments,
+        delta=True,
+        nrow=8,
+    )
+
+    log_string(f"wrote {filename}")
 
 
 ######################################################################
@@ -1284,6 +1259,47 @@ nb_parameters = sum(p.numel() for p in models[0].parameters())
 log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 
 
+######################################################################
+
+if args.quizzes is not None:
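+    # Read quizzes from a text file, let every model regenerate each of the
+    # four quadrants in turn, save all the results as a single image, and exit.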
+    with open(args.quizzes, "r") as file:
+        txt = file.read()
+
+    quizzes = quiz_machine.problem.text2quiz(txt)
+
+    record = []
+
+    quizzes = quizzes.to(main_device)
+    for model in models:
+        log_string(f"processing {model.id} {args.quizzes}")
+        for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
+            mask_generate = quiz_machine.make_quiz_mask(
+                quizzes=quizzes, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
+            )
+            result = ae_generate(
+                model,
+                (1 - mask_generate) * quizzes,
+                mask_generate,
+            )
+            record.append(result)
+
+    result = torch.cat(record, dim=0)
+
+    filename = "result.png"
+
+    quiz_machine.problem.save_quizzes_as_image(
+        args.result_dir,
+        filename,
+        quizzes=result,
+        delta=True,
+        nrow=8,
+    )
+
+    log_string(f"wrote {filename}")
+
+    exit(0)
+
+
 ######################################################################
 
 last_n_epoch_c_quizzes = 0
@@ -1374,6 +1390,9 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
         # --------------------------------------------------------------------
 
+        filename = f"culture_c_quiz_{n_epoch:04d}.png"
+        save_c_quizzes_with_scores(models, c_quizzes[:128], filename)
+
         log_string(f"generated_c_quizzes {c_quizzes.size()=}")
 
         time_train = 0