Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 1 Aug 2024 06:54:27 +0000 (08:54 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 1 Aug 2024 06:54:27 +0000 (08:54 +0200)
grids.py
main.py
quiz_machine.py

index 8d274ad..f12fcb9 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -167,6 +167,19 @@ class Grids(problem.Problem):
         self.check_structure(quizzes, struct)
         return struct
 
+    def inject_noise(self, quizzes, noise, struct, mask):
+        assert self.check_structure(quizzes, struct=struct)
+        S = self.height * self.width
+        mask = torch.tensor(mask, device=quizzes.device)
+        mask = mask[None, :, None].expand(1, 4, S + 1)
+        mask = mask * (torch.rand(mask.size(), device=mask.device) <= noise).long()
+        mask = mask.reshape(1, -1).expand_as(quizzes)
+        random = torch.randint(self.nb_colors, mask.size())
+
+        quizzes = mask * random + (1 - mask) * quizzes
+
+        return quizzes
+
     # What a mess
     def reconfigure(self, quizzes, struct=("A", "f_A", "B", "f_B")):
         if torch.is_tensor(quizzes):
@@ -237,7 +250,8 @@ class Grids(problem.Problem):
     ):
         self.colors = torch.tensor([c for _, c in self.named_colors])
 
-        self.token_A = len(self.colors)
+        self.nb_colors = len(self.colors)
+        self.token_A = self.nb_colors
         self.token_f_A = self.token_A + 1
         self.token_B = self.token_f_A + 1
         self.token_f_B = self.token_B + 1
@@ -294,7 +308,7 @@ class Grids(problem.Problem):
     ######################################################################
 
     def grid2img(self, x, scale=15):
-        m = torch.logical_and(x >= 0, x < len(self.colors)).long()
+        m = torch.logical_and(x >= 0, x < self.nb_colors).long()
         y = self.colors[x * m].permute(0, 3, 1, 2)
         s = y.shape
         y = y[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
@@ -562,7 +576,7 @@ class Grids(problem.Problem):
     # @torch.compile
     def task_replace_color(self, A, f_A, B, f_B):
         nb_rec = 3
-        c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+        c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             r = self.rec_coo(nb_rec, prevent_overlap=True)
             for n in range(nb_rec):
@@ -578,7 +592,7 @@ class Grids(problem.Problem):
                 break
 
         nb_rec = 3
-        c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+        c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
                 r = self.rec_coo(nb_rec, prevent_overlap=True)
@@ -603,7 +617,7 @@ class Grids(problem.Problem):
     def task_grow(self, A, f_A, B, f_B):
         di, dj = torch.randint(2, (2,)) * 2 - 1
         nb_rec = 3
-        c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+        c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
         direction = torch.randint(2, (1,)).item()
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
@@ -629,7 +643,7 @@ class Grids(problem.Problem):
     def task_half_fill(self, A, f_A, B, f_B):
         di, dj = torch.randint(2, (2,)) * 2 - 1
         nb_rec = 3
-        c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1
+        c = torch.randperm(self.nb_colors - 1)[: 2 * nb_rec] + 1
         direction = torch.randint(4, (1,)).item()
         for X, f_X in [(A, f_A), (B, f_B)]:
             r = self.rec_coo(nb_rec, prevent_overlap=True)
@@ -670,7 +684,7 @@ class Grids(problem.Problem):
     # @torch.compile
     def task_frame(self, A, f_A, B, f_B):
         nb_rec = 3
-        c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+        c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             r = self.rec_coo(nb_rec, prevent_overlap=True)
             for n in range(nb_rec):
@@ -687,7 +701,7 @@ class Grids(problem.Problem):
     # @torch.compile
     def task_detect(self, A, f_A, B, f_B):
         nb_rec = 3
-        c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+        c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             r = self.rec_coo(nb_rec, prevent_overlap=True)
             for n in range(nb_rec):
@@ -739,7 +753,7 @@ class Grids(problem.Problem):
 
             N = 3
             c = torch.zeros(N + 2, dtype=torch.int64)
-            c[1:] = torch.randperm(len(self.colors) - 1)[: N + 1] + 1
+            c[1:] = torch.randperm(self.nb_colors - 1)[: N + 1] + 1
 
             for X, f_X in [(A, f_A), (B, f_B)]:
                 if not hasattr(self, "cache_count") or len(self.cache_count) == 0:
@@ -785,7 +799,7 @@ class Grids(problem.Problem):
 
     # @torch.compile
     def task_trajectory(self, A, f_A, B, f_B):
-        c = torch.randperm(len(self.colors) - 1)[:2] + 1
+        c = torch.randperm(self.nb_colors - 1)[:2] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
                 di, dj = torch.randint(7, (2,)) - 3
@@ -816,7 +830,7 @@ class Grids(problem.Problem):
 
     # @torch.compile
     def task_bounce(self, A, f_A, B, f_B):
-        c = torch.randperm(len(self.colors) - 1)[:3] + 1
+        c = torch.randperm(self.nb_colors - 1)[:3] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             # @torch.compile
             def free(i, j):
@@ -885,7 +899,7 @@ class Grids(problem.Problem):
 
     # @torch.compile
     def task_scale(self, A, f_A, B, f_B):
-        c = torch.randperm(len(self.colors) - 1)[:2] + 1
+        c = torch.randperm(self.nb_colors - 1)[:2] + 1
 
         i, j = (
             torch.randint(self.height // 2, (1,)).item(),
@@ -915,7 +929,7 @@ class Grids(problem.Problem):
     # @torch.compile
     def task_symbols(self, A, f_A, B, f_B):
         nb_rec = 4
-        c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+        c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
         delta = 3
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
@@ -967,7 +981,7 @@ class Grids(problem.Problem):
                 X[...] = 0
                 f_X[...] = 0
 
-                c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+                c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
 
                 for r in range(nb_rec):
                     while True:
@@ -1034,7 +1048,7 @@ class Grids(problem.Problem):
 
     # @torch.compile
     def REMOVED_task_distance(self, A, f_A, B, f_B):
-        c = torch.randperm(len(self.colors) - 1)[:3] + 1
+        c = torch.randperm(self.nb_colors - 1)[:3] + 1
         dist0 = torch.empty(self.height + 2, self.width + 2)
         dist1 = torch.empty(self.height + 2, self.width + 2)
         for X, f_X in [(A, f_A), (B, f_B)]:
@@ -1099,7 +1113,7 @@ class Grids(problem.Problem):
     def TOO_HARD_task_puzzle(self, A, f_A, B, f_B):
         S = 4
         i0, j0 = (self.height - S) // 2, (self.width - S) // 2
-        c = torch.randperm(len(self.colors) - 1)[:4] + 1
+        c = torch.randperm(self.nb_colors - 1)[:4] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
                 f_X[...] = 0
@@ -1165,7 +1179,7 @@ class Grids(problem.Problem):
                             X[ii + i, jj + j] = c[d]
 
     def TOO_MESSY_task_islands(self, A, f_A, B, f_B):
-        c = torch.randperm(len(self.colors) - 1)[:2] + 1
+        c = torch.randperm(self.nb_colors - 1)[:2] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             if not hasattr(self, "cache_islands") or len(self.cache_islands) == 0:
                 self.cache_islands = list(
@@ -1196,7 +1210,7 @@ class Grids(problem.Problem):
     # @torch.compile
     def TOO_HARD_task_stack(self, A, f_A, B, f_B):
         N = 5
-        c = torch.randperm(len(self.colors) - 1)[:N] + 1
+        c = torch.randperm(self.nb_colors - 1)[:N] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             i1, j1, i2, j2 = (
                 self.height // 2 - 1,
@@ -1241,7 +1255,7 @@ class Grids(problem.Problem):
 
     def TOO_HARD_task_matrices(self, A, f_A, B, f_B):
         N = 6
-        c = torch.randperm(len(self.colors) - 1)[:N] + 1
+        c = torch.randperm(self.nb_colors - 1)[:N] + 1
 
         for X, f_X in [(A, f_A), (B, f_B)]:
             M1 = torch.randint(2, (5, 5))
@@ -1257,7 +1271,7 @@ class Grids(problem.Problem):
 
     def TOO_HARD_task_compute(self, A, f_A, B, f_B):
         N = 6
-        c = torch.randperm(len(self.colors) - 1)[:N] + 1
+        c = torch.randperm(self.nb_colors - 1)[:N] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             v = torch.randint((self.width - 1) // 2, (N,)) + 1
             chain = torch.randperm(N)
@@ -1307,7 +1321,7 @@ class Grids(problem.Problem):
             return min(max(v, 0) + max(h + 1, 0), max(v + 1, 0) + max(h, 0))
 
         nb_rec = 3
-        c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+        c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
                 r = self.rec_coo(nb_rec, prevent_overlap=True)
@@ -1330,7 +1344,7 @@ class Grids(problem.Problem):
     def task_corners(self, A, f_A, B, f_B):
         polarity = torch.randint(2, (1,)).item()
         nb_rec = 3
-        c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+        c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             r = self.rec_coo(nb_rec, prevent_overlap=True)
 
@@ -1367,7 +1381,7 @@ class Grids(problem.Problem):
     # @torch.compile
     def task_path(self, A, f_A, B, f_B):
         nb_rec = 2
-        c = torch.randperm(len(self.colors) - 1)[: nb_rec + 2] + 1
+        c = torch.randperm(self.nb_colors - 1)[: nb_rec + 2] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
                 X[...] = 0
@@ -1403,7 +1417,7 @@ class Grids(problem.Problem):
     # @torch.compile
     def task_fill(self, A, f_A, B, f_B):
         nb_rec = 3
-        c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+        c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             accept_full = torch.rand(1) < 0.5
 
@@ -1435,7 +1449,7 @@ class Grids(problem.Problem):
                     break
 
     def TOO_HARD_task_addition(self, A, f_A, B, f_B):
-        c = torch.randperm(len(self.colors) - 1)[:4] + 1
+        c = torch.randperm(self.nb_colors - 1)[:4] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             N1 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item()
             N2 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item()
@@ -1452,7 +1466,7 @@ class Grids(problem.Problem):
 
     def task_science_implicit(self, A, f_A, B, f_B):
         nb_rec = 5
-        c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+        c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
 
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
@@ -1505,7 +1519,7 @@ class Grids(problem.Problem):
 
     def task_science_dot(self, A, f_A, B, f_B):
         nb_rec = 3
-        c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+        c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
                 X[...] = 0
@@ -1539,7 +1553,7 @@ class Grids(problem.Problem):
         return False
 
     def task_science_tag(self, A, f_A, B, f_B):
-        c = torch.randperm(len(self.colors) - 1)[:4] + 1
+        c = torch.randperm(self.nb_colors - 1)[:4] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             rs = []
             while len(rs) < 4:
diff --git a/main.py b/main.py
index cce747a..72b2b26 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -43,8 +43,7 @@ parser.add_argument("--max_percents_of_test_in_train", type=int, default=-1)
 
 parser.add_argument("--log_command", type=str, default=None)
 
-########################################
-
+# ----------------------------------
 parser.add_argument("--nb_epochs", type=int, default=10000)
 
 parser.add_argument("--batch_size", type=int, default=None)
@@ -63,8 +62,7 @@ parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None)
 
 parser.add_argument("--learning_rate", type=float, default=5e-4)
 
-########################################
-
+# ----------------------------------
 parser.add_argument("--model", type=str, default=None)
 
 parser.add_argument("--dim_model", type=int, default=None)
@@ -79,8 +77,7 @@ parser.add_argument("--nb_blocks", type=int, default=None)
 
 parser.add_argument("--dropout", type=float, default=0.1)
 
-########################################
-
+# ----------------------------------
 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
 
 parser.add_argument("--problem", type=str, default="grids")
@@ -89,6 +86,8 @@ parser.add_argument("--nb_threads", type=int, default=1)
 
 parser.add_argument("--gpus", type=str, default="all")
 
+# ----------------------------------
+
 parser.add_argument("--nb_gpts", type=int, default=5)
 
 parser.add_argument("--max_fail_to_validate", type=int, default=2)
@@ -103,7 +102,7 @@ parser.add_argument("--temperature_hot", type=float, default=1.5)
 
 parser.add_argument("--temperature_cold", type=float, default=1)
 
-parser.add_argument("--nb_rounds", type=int, default=1)
+parser.add_argument("--prompt_noise", type=float, default=0.0)
 
 parser.add_argument("--dirty_debug", action="store_true", default=False)
 
@@ -343,6 +342,7 @@ quiz_machine = quiz_machine.QuizMachine(
     problem=problem,
     batch_size=args.inference_batch_size,
     result_dir=args.result_dir,
+    prompt_noise=args.prompt_noise,
     logger=log_string,
     device=main_device,
 )
@@ -578,14 +578,6 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 
         c_quizzes = c_quizzes[to_keep]
 
-        # We go through nb_rounds rounds and keep only quizzes on
-        # which
-        #
-        # (1) models respond always the same through rounds, and
-        #
-        # (2) at least one and up to max_fail_to_validate model(s)
-        # fail(s)
-
         # This is nb_quizzes x nb_models
 
         seq_logproba = quiz_machine.models_logprobas(
index 90879ce..1e973c5 100755 (executable)
@@ -67,6 +67,7 @@ class QuizMachine:
         problem,
         batch_size,
         result_dir,
+        prompt_noise,
         logger,
         device=torch.device("cpu"),
     ):
@@ -78,6 +79,7 @@ class QuizMachine:
         self.logger = logger
         self.prompt_len = None
         self.answer_len = None
+        self.prompt_noise = prompt_noise
 
         self.understood_structures = [
             (("A", "f_A", "B", "f_B"), (0, 0, 0, 1)),
@@ -153,6 +155,7 @@ class QuizMachine:
 
             if len(c_quizzes) > 0:
                 c_quizzes = torch.cat(c_quizzes, dim=0)
+
                 if c_quizzes.size(0) > w_quizzes.size(0) // 2:
                     i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
                     c_quizzes = c_quizzes[i]
@@ -171,13 +174,19 @@ class QuizMachine:
                 quizzes = w_quizzes.clone()
                 from_w = torch.full((quizzes.size(0),), True, device=quizzes.device)
 
-            self.randomize_configuations_inplace(
-                quizzes, structs=[s for s, m in self.understood_structures]
+        i = torch.randperm(quizzes.size(0), device=quizzes.device)
+        quizzes, from_w = quizzes[i], from_w[i]
+
+        if self.prompt_noise > 0.0:
+            quizzes = self.problem.inject_noise(
+                quizzes, self.prompt_noise, ("A", "f_A", "B", "f_B"), (1, 0, 1, 0)
             )
 
-            i = torch.randperm(quizzes.size(0), device=quizzes.device)
+        self.randomize_configuations_inplace(
+            quizzes, structs=[s for s, m in self.understood_structures]
+        )
 
-            return quizzes[i], from_w[i]
+        return quizzes, from_w
 
     ######################################################################
 
@@ -267,20 +276,24 @@ class QuizMachine:
 
     def renew_train_w_quizzes(self, model):
         if hasattr(model, "hard_w_quizzes"):
-            if model.hard_w_quizzes.size(0) >= model.train_w_quizzes.size(0):
+            hard_w_quizzes = self.problem.reconfigure(
+                model.hard_w_quizzes, struct=("A", "f_A", "B", "f_B")
+            )
+            self.logger(
+                f"re-using {hard_w_quizzes.size(0)} hard world quizzes from model {model.id}"
+            )
+            if hard_w_quizzes.size(0) >= model.train_w_quizzes.size(0):
                 nb_to_generate = 0
-                model.train_w_quizzes[...] = model.hard_w_quizzes[
+                model.train_w_quizzes[...] = hard_w_quizzes[
                     torch.randperm(hard_w_quizzes.size(0))[
                         model.train_w_quizzes.size(0)
                     ]
                 ]
             else:
-                nb_to_generate = model.train_w_quizzes.size(
-                    0
-                ) - model.hard_w_quizzes.size(0)
+                nb_to_generate = model.train_w_quizzes.size(0) - hard_w_quizzes.size(0)
                 model.train_w_quizzes[...] = torch.cat(
                     [
-                        model.hard_w_quizzes,
+                        hard_w_quizzes,
                         self.problem.generate_w_quizzes(nb_to_generate),
                     ],
                     dim=0,
@@ -291,10 +304,6 @@ class QuizMachine:
                 model.train_w_quizzes.size(0)
             )
 
-        self.logger(
-            f"re-using {model.hard_w_quizzes.size(0)} hard world quizzes from model {model.id}"
-        )
-
         self.randomize_configuations_inplace(
             model.train_w_quizzes, structs=[s for s, m in self.understood_structures]
         )