Update.
author François Fleuret <francois@fleuret.org>
Thu, 25 Jul 2024 14:54:46 +0000 (16:54 +0200)
committer François Fleuret <francois@fleuret.org>
Thu, 25 Jul 2024 14:54:46 +0000 (16:54 +0200)
grids.py
main.py
quiz_machine.py

diff --git a/grids.py b/grids.py
index 93b027a..1d94e07 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -5,7 +5,7 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
-import math, sys, tqdm, os, warnings
+import math, sys, tqdm, os, warnings, cairo
 
 import torch, torchvision
 
@@ -14,6 +14,36 @@ from torch.nn import functional as F
 
 ######################################################################
 
+
+def text_img(height, width, text):
+    pixel_map = torch.full((height, width, 4), 255, dtype=torch.uint8)
+
+    surface = cairo.ImageSurface.create_for_data(
+        pixel_map.numpy(), cairo.FORMAT_ARGB32, pixel_map.size(1), pixel_map.size(0)
+    )
+
+    ctx = cairo.Context(surface)
+    ctx.set_source_rgb(0, 0, 0)
+    ctx.set_font_size(16)
+    ctx.select_font_face("courier", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL)
+    y = None
+    for line in text.split("\n"):
+        xbearing, ybearing, width, height, dx, dy = ctx.text_extents(line)
+        if y is None:
+            y = height * 1.5
+            x = height * 0.5
+
+        ctx.move_to(x, y)
+        ctx.show_text(line)
+        y += height * 1.5
+
+    ctx.stroke()
+
+    return pixel_map.permute(2, 0, 1)[None, :3].contiguous()
+
+
+######################################################################
+
 import problem
 
 
@@ -203,7 +233,8 @@ class Grids(problem.Problem):
         max_nb_cached_chunks=None,
         chunk_size=None,
         nb_threads=-1,
-        tasks=None,
+        world_tasks=None,
+        science_tasks=None,
     ):
         self.colors = torch.tensor([c for _, c in self.named_colors])
 
@@ -233,7 +264,7 @@ class Grids(problem.Problem):
 
         self.cache_rec_coo = {}
 
-        all_tasks = [
+        self.all_tasks = [
             self.task_replace_color,
             self.task_translate,
             self.task_grow,
@@ -254,10 +285,17 @@ class Grids(problem.Problem):
             # self.task_islands, # TOO MESSY
         ]
 
-        if tasks is None:
-            self.all_tasks = all_tasks
+        if world_tasks is None:
+            self.world_tasks = self.all_tasks
         else:
-            self.all_tasks = [getattr(self, "task_" + t) for t in tasks.split(",")]
+            self.world_tasks = [
+                getattr(self, "task_" + t) for t in world_tasks.split(",")
+            ]
+
+        if science_tasks is not None:
+            self.science_tasks = [
+                getattr(self, "task_" + t) for t in science_tasks.split(",")
+            ]
 
         super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
 
@@ -305,6 +343,8 @@ class Grids(problem.Problem):
         quizzes,
         predicted_parts=None,
         correct_parts=None,
+        comments=None,
+        comment_height=64,
         nrow=4,
         margin=8,
     ):
@@ -365,6 +405,11 @@ class Grids(problem.Problem):
 
         img = torch.cat([img_A, img_f_A, img_B, img_f_B], dim=3)
 
+        if comments is not None:
+            comment_img = [text_img(comment_height, img.size(3), t) for t in comments]
+            comment_img = torch.cat(comment_img, dim=0)
+            img = torch.cat([img, comment_img], dim=2)
+
         image_name = os.path.join(result_dir, filename)
 
         torchvision.utils.save_image(
@@ -1414,11 +1459,14 @@ class Grids(problem.Problem):
 
         return quizzes
 
-    def generate_w_quizzes_(self, nb, tasks=None, progress_bar=False):
+    def generate_w_quizzes_(self, nb, tasks=None, science=False, progress_bar=False):
         S = self.height * self.width
 
         if tasks is None:
-            tasks = self.all_tasks
+            if science:
+                tasks = self.science_tasks
+            else:
+                tasks = self.world_tasks
 
         quizzes = self.create_empty_quizzes(nb, ("A", "f_A", "B", "f_B"))
 
@@ -1439,9 +1487,10 @@ class Grids(problem.Problem):
 
         return quizzes
 
-    def save_some_examples(self, result_dir):
+    def save_some_examples(self, result_dir, science=False):
         nb, nrow = 128, 4
-        for t in self.all_tasks:
+        tasks = self.science_tasks if science else self.world_tasks
+        for t in tasks:
             print(t.__name__)
             quizzes = self.generate_w_quizzes_(nb, tasks=[t])
             self.save_quizzes_as_image(
@@ -1496,7 +1545,8 @@ if __name__ == "__main__":
     nb, nrow = 128, 4
     # nb, nrow = 8, 2
 
-    # for t in grids.all_tasks:
+    # for t in grids.world_tasks:
+
     for t in [grids.task_path]:
         print(t.__name__)
         quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
@@ -1504,6 +1554,7 @@ if __name__ == "__main__":
             "/tmp",
             t.__name__ + ".png",
             quizzes,
+            comments=[f"{t.__name__} #{k}" for k in range(quizzes.size(0))],
         )
 
     # exit(0)
diff --git a/main.py b/main.py
index 257f40f..b49fa06 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -112,12 +112,27 @@ grids_tasks = ", ".join(
 )
 
 parser.add_argument(
-    "--grids_tasks",
+    "--grids_world_tasks",
     type=str,
     default=None,
     help="A comma-separated subset of: " + grids_tasks + ", or None for all.",
 )
 
+parser.add_argument(
+    "--grids_science_tasks",
+    type=str,
+    default=None,
+    help="A comma-separated subset of: " + grids_tasks + ", or None for all.",
+)
+
+assert (
+    len(
+        set(args.grids_world_tasks.split(","))
+        & set(args.grids_science_tasks.split(","))
+    )
+    == 0
+), "World and science task have to be disjoint"
+
 ######################################################################
 
 parser.add_argument("--sky_height", type=int, default=6)
@@ -290,14 +305,17 @@ if args.problem == "sky":
         nb_threads=args.nb_threads,
     )
     back_accuracy = False
+
 elif args.problem == "grids":
     problem = grids.Grids(
         max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
         chunk_size=100,
         nb_threads=args.nb_threads,
-        tasks=args.grids_tasks,
+        world_tasks=args.grids_world_tasks,
+        science_tasks=args.grids_science_tasks,
     )
     back_accuracy = True
+
 else:
     raise ValueError
 
@@ -465,9 +483,9 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
         # fail(s)
 
         # This is nb_quizzes x nb_models
-        number_correct_responses = 0
 
-        remains = [c_quizzes.size(0)]
+        number_correct_responses = 0
+        nb_remaining = [c_quizzes.size(0)]
 
         for r in range(args.nb_rounds):
             if c_quizzes.size(0) == 0:
@@ -487,7 +505,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
             c_quizzes = c_quizzes[to_keep]
             number_correct_responses = number_correct_responses[to_keep]
 
-            remains.append(c_quizzes.size(0))
+            nb_remaining.append(c_quizzes.size(0))
 
         if c_quizzes.size(0) > 0:
             nb_validated_per_model[model_for_generation.id] += c_quizzes.size(0)
@@ -512,7 +530,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
         else:
             e = "???"
 
-        v = " ".join([str(n) for n in remains])
+        v = " ".join([str(n) for n in nb_remaining])
         log_string(f"filter c_quizzes {v}")
 
         log_string(
@@ -526,11 +544,9 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 
     v_train = validated_quizzes[:nb_for_train]
     quiz_machine.store_c_quizzes(v_train, for_train=True)
-    quiz_machine.store_c_quizzes(quiz_machine.problem.p_a_flip(v_train), for_train=True)
 
     v_test = validated_quizzes[nb_for_train:nb_to_validate]
     quiz_machine.store_c_quizzes(v_test, for_train=False)
-    quiz_machine.store_c_quizzes(quiz_machine.problem.p_a_flip(v_test), for_train=False)
 
     ######################################################################
     # save images
@@ -538,19 +554,19 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
     vq = validated_quizzes[torch.randperm(validated_quizzes.size(0))[:128]]
 
     if vq.size(0) > 0:
-        prefix = f"culture_c_quiz_{n_epoch:04d}"
-
         number_correct_responses = 0
         for r in range(args.nb_rounds):
             number_correct_responses += quiz_machine.models_successes(models, vq)
 
-        with open(os.path.join(args.result_dir, prefix + "_responses.dat"), "w") as f:
-            for n, r in enumerate(number_correct_responses):
-                v = " ".join([str(n.item()) for n in r])
-                f.write(f"{n}: {v}\n")
+        comments = []
+        for r in number_correct_responses:
+            comments.append("nb_correct " + " ".join([str(n.item()) for n in r]))
 
         vq = quiz_machine.problem.reconfigure(vq, ("A", "f_A", "B", "f_B"))
-        quiz_machine.problem.save_quizzes_as_image(args.result_dir, prefix, vq)
+        filename = f"culture_c_quiz_{n_epoch:04d}.png"
+        quiz_machine.problem.save_quizzes_as_image(
+            args.result_dir, filename, vq, comments=comments
+        )
 
 
 ######################################################################
@@ -574,16 +590,22 @@ for k in range(args.nb_gpts):
     model.main_test_accuracy = 0.0
     model.id = k
 
-    quiz_machine.create_w_quizzes(
-        model=model,
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
+    model.train_w_quizzes = quiz_machine.problem.generate_w_quizzes(
+        args.nb_train_samples
     )
 
+    model.test_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples)
+
     models.append(model)
 
 ######################################################################
 
+science_w_quizzes = quiz_machine.problem.generate_w_quizzes(
+    args.nb_test_samples, science=True
+)
+
+######################################################################
+
 current_epoch = 0
 
 if args.resume:
@@ -728,6 +750,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         c_quizzes = quiz_machine.problem.reconfigure(
             c_quizzes, ("A", "f_A", "B", "f_B")
         )
+
         quiz_machine.problem.save_quizzes_as_image(
             args.result_dir,
             f"non_validated_{n_epoch:04d}_{model.id:02d}.png",
diff --git a/quiz_machine.py b/quiz_machine.py
index 2ca584e..4048b39 100755 (executable)
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -172,14 +172,16 @@ class QuizMachine:
                 from_w = torch.arange(
                     quizzes.size(0), device=quizzes.device
                 ) < w_quizzes.size(0)
-                i = torch.randperm(quizzes.size(0), device=quizzes.device)
-
-                return quizzes[i], from_w[i]
 
             else:
-                return w_quizzes, torch.full(
-                    (w_quizzes.size(0),), True, device=w_quizzes.device
-                )
+                quizzes = w_quizzes.clone()
+                from_w = torch.full((quizzes.size(0),), True, device=quizzes.device)
+
+            self.randomize_configuations_inplace(quizzes, structs=self.train_struct)
+
+            i = torch.randperm(quizzes.size(0), device=quizzes.device)
+
+            return quizzes[i], from_w[i]
 
     ######################################################################
 
@@ -199,7 +201,6 @@ class QuizMachine:
             input=result,
             ar_mask=ar_mask,
             seq_logproba=seq_logproba,
-            deterministic_synthesis=False,
             progress_bar_desc="accuracy",
             device=self.device,
         )
@@ -219,7 +220,9 @@ class QuizMachine:
         result = input.new(input.size())
         correct = input.new(input.size(0))
         predicted_parts = input.new(input.size(0), 4)
+
         nb = 0
+
         for struct, mask in [
             (("A", "f_A", "B", "f_B"), (0, 0, 0, 1)),
             (("f_A", "A", "f_B", "B"), (0, 0, 0, 1)),
@@ -278,47 +281,38 @@ class QuizMachine:
                 quizzes[r == c], struct=structs[c]
             )
 
-    def create_w_quizzes(self, model, nb_train_samples, nb_test_samples):
-        model.train_w_quizzes = self.problem.generate_w_quizzes(nb_train_samples)
-        model.test_w_quizzes = self.problem.generate_w_quizzes(nb_test_samples)
-
-        self.randomize_configuations_inplace(
-            model.train_w_quizzes, structs=self.train_struct
-        )
-
-        self.randomize_configuations_inplace(
-            model.test_w_quizzes, structs=self.train_struct
-        )
-
     ######################################################################
 
     def renew_train_w_quizzes(self, model):
         if hasattr(model, "hard_w_quizzes"):
-            self.logger(
-                f"re-using {model.hard_w_quizzes.size(0)} hard world quizzes from model {model.id}"
-            )
-
             if model.hard_w_quizzes.size(0) >= model.train_w_quizzes.size(0):
+                nb_to_generate = 0
                 model.train_w_quizzes[...] = model.hard_w_quizzes[
                     torch.randperm(hard_w_quizzes.size(0))[
                         model.train_w_quizzes.size(0)
                     ]
                 ]
             else:
+                nb_to_generate = model.train_w_quizzes.size(
+                    0
+                ) - model.hard_w_quizzes.size(0)
                 model.train_w_quizzes[...] = torch.cat(
                     [
                         model.hard_w_quizzes,
-                        self.problem.generate_w_quizzes(
-                            model.train_w_quizzes.size(0) - model.hard_w_quizzes.size(0)
-                        ),
+                        self.problem.generate_w_quizzes(nb_to_generate),
                     ],
                     dim=0,
                 )
         else:
+            nb_to_generate = 0
             model.train_w_quizzes[...] = self.problem.generate_w_quizzes(
                 model.train_w_quizzes.size(0)
             )
 
+        self.logger(
+            f"re-using {model.hard_w_quizzes.size(0)} hard world quizzes from model {model.id}"
+        )
+
         self.randomize_configuations_inplace(
             model.train_w_quizzes, structs=self.train_struct
         )
@@ -409,7 +403,6 @@ class QuizMachine:
                 input=result,
                 ar_mask=ar_mask,
                 seq_logproba=seq_logproba[:, model.id],
-                deterministic_synthesis=False,
                 device=self.device,
             )
 
@@ -430,7 +423,6 @@ class QuizMachine:
                 input=result,
                 ar_mask=ar_mask,
                 seq_logproba=seq_logproba[:, model.id],
-                deterministic_synthesis=False,
                 device=self.device,
             )
 
@@ -451,9 +443,8 @@ class QuizMachine:
         temperature_hot=1.0,
         temperature_cold=1.0,
     ):
-        c_quizzes = self.problem.create_empty_quizzes(nb, ("f_B", "f_A", "A", "B")).to(
-            self.device
-        )
+        c_quizzes = self.problem.create_empty_quizzes(nb, ("f_B", "f_A", "A", "B"))
+        c_quizzes = c_quizzes.to(self.device)
 
         seq_logproba = torch.zeros(nb, device=self.device)
 
@@ -469,7 +460,6 @@ class QuizMachine:
             ),
             seq_logproba=seq_logproba,
             logit_transformer=lt_noisy,
-            deterministic_synthesis=False,
             device=self.device,
         )
 
@@ -482,7 +472,6 @@ class QuizMachine:
             ),
             seq_logproba=seq_logproba,
             logit_transformer=lt_clean,
-            deterministic_synthesis=False,
             device=self.device,
         )
 
@@ -497,7 +486,6 @@ class QuizMachine:
             ),
             seq_logproba=seq_logproba,
             logit_transformer=lt_clean,
-            deterministic_synthesis=False,
             device=self.device,
         )