Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 9 Jul 2024 19:48:58 +0000 (21:48 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 9 Jul 2024 19:48:58 +0000 (21:48 +0200)
grids.py
main.py
quiz_machine.py

index ba09225..85d640d 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -41,6 +41,7 @@ class Grids(problem.Problem):
         self.colors = torch.tensor([c for _, c in self.named_colors])
         self.height = 10
         self.width = 10
+        self.cache_rec_coo = {}
         super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
 
     ######################################################################
@@ -199,41 +200,134 @@ class Grids(problem.Problem):
     def nb_token_values(self):
         return len(self.colors)
 
-    def rec_coo(self, nb_rec, min_height=3, min_width=3):
-        N = 10
+    # @torch.compile
+    def rec_coo(
+        self,
+        nb_rec,
+        min_height=3,
+        min_width=3,
+        surface_max=None,
+        prevent_overlap=False,
+    ):
+        if surface_max is None:
+            surface_max = self.height * self.width // 2
+
+        signature = (nb_rec, min_height, min_width, surface_max)
+
+        try:
+            return self.cache_rec_coo[signature].pop()
+        except IndexError:
+            pass
+        except KeyError:
+            pass
+
+        N = 10000
         while True:
-            i = torch.randint(self.height, (N, nb_rec, 2)).sort(dim=-1).values
-            j = torch.randint(self.width, (N, nb_rec, 2)).sort(dim=-1).values
-            if nb_rec == 2:
-                A_i1, A_i2, A_j1, A_j2 = i[:, 0, 0], i[:, 0, 1], j[:, 0, 0], j[:, 0, 1]
-                B_i1, B_i2, B_j1, B_j2 = i[:, 1, 0], i[:, 1, 1], j[:, 1, 0], j[:, 1, 1]
-                no_overlap = torch.logical_not(
-                    (A_i1 > B_i2) & (A_i2 < B_i1) & (A_j1 > B_j1) & (A_j2 < B_j1)
+            while True:
+                i = torch.randint(self.height, (N * nb_rec, 2)).sort(dim=-1).values
+                j = torch.randint(self.width, (N * nb_rec, 2)).sort(dim=-1).values
+
+                big_enough = (
+                    (i[:, 1] >= i[:, 0] + min_height)
+                    & (j[:, 1] >= j[:, 0] + min_height)
+                    & ((i[:, 1] - i[:, 0]) * (j[:, 1] - j[:, 0]) <= surface_max)
                 )
-                i, j = i[no_overlap], j[no_overlap]
-            elif nb_rec == 3:
-                A_i1, A_i2, A_j1, A_j2 = i[:, 0, 0], i[:, 0, 1], j[:, 0, 0], j[:, 0, 1]
-                B_i1, B_i2, B_j1, B_j2 = i[:, 1, 0], i[:, 1, 1], j[:, 1, 0], j[:, 1, 1]
-                C_i1, C_i2, C_j1, C_j2 = i[:, 2, 0], i[:, 2, 1], j[:, 2, 0], j[:, 2, 1]
-                no_overlap = (
-                    torch.logical_not(
-                        (A_i1 > B_i2) & (A_i2 < B_i1) & (A_j1 > B_j1) & (A_j2 < B_j1)
+
+                i, j = i[big_enough], j[big_enough]
+
+                n = i.size(0) - i.size(0) % nb_rec
+
+                if n > 0:
+                    break
+
+            i = i[:n].reshape(n // nb_rec, nb_rec, -1)
+            j = j[:n].reshape(n // nb_rec, nb_rec, -1)
+
+            if prevent_overlap:
+                can_fit = ((i[:, :, 1] - i[:, :, 0]) * (j[:, :, 1] - j[:, :, 0])).sum(
+                    dim=-1
+                ) <= self.height * self.width
+                i, j = i[can_fit], j[can_fit]
+                if nb_rec == 2:
+                    A_i1, A_i2, A_j1, A_j2 = (
+                        i[:, 0, 0],
+                        i[:, 0, 1],
+                        j[:, 0, 0],
+                        j[:, 0, 1],
                     )
-                    & torch.logical_not(
-                        (A_i1 > C_i2) & (A_i2 < C_i1) & (A_j1 > C_j1) & (A_j2 < C_j1)
+                    B_i1, B_i2, B_j1, B_j2 = (
+                        i[:, 1, 0],
+                        i[:, 1, 1],
+                        j[:, 1, 0],
+                        j[:, 1, 1],
                     )
-                    & torch.logical_not(
-                        (B_i1 > C_i2) & (B_i2 < C_i1) & (B_j1 > C_j1) & (B_j2 < C_j1)
+                    no_overlap = torch.logical_not(
+                        (A_i1 >= B_i2)
+                        & (A_i2 <= B_i1)
+                        & (A_j1 >= B_j1)
+                        & (A_j2 <= B_j1)
                     )
-                )
-                i, j = (i[no_overlap], j[no_overlap])
-            else:
-                assert nb_rec == 1
+                    i, j = i[no_overlap], j[no_overlap]
+                elif nb_rec == 3:
+                    A_i1, A_i2, A_j1, A_j2 = (
+                        i[:, 0, 0],
+                        i[:, 0, 1],
+                        j[:, 0, 0],
+                        j[:, 0, 1],
+                    )
+                    B_i1, B_i2, B_j1, B_j2 = (
+                        i[:, 1, 0],
+                        i[:, 1, 1],
+                        j[:, 1, 0],
+                        j[:, 1, 1],
+                    )
+                    C_i1, C_i2, C_j1, C_j2 = (
+                        i[:, 2, 0],
+                        i[:, 2, 1],
+                        j[:, 2, 0],
+                        j[:, 2, 1],
+                    )
+                    no_overlap = (
+                        (
+                            (A_i1 >= B_i2)
+                            | (A_i2 <= B_i1)
+                            | (A_j1 >= B_j2)
+                            | (A_j2 <= B_j1)
+                        )
+                        & (
+                            (A_i1 >= C_i2)
+                            | (A_i2 <= C_i1)
+                            | (A_j1 >= C_j2)
+                            | (A_j2 <= C_j1)
+                        )
+                        & (
+                            (B_i1 >= C_i2)
+                            | (B_i2 <= C_i1)
+                            | (B_j1 >= C_j2)
+                            | (B_j2 <= C_j1)
+                        )
+                    )
+                    i, j = (i[no_overlap], j[no_overlap])
+                else:
+                    assert nb_rec == 1
 
             if i.size(0) > 1:
                 break
 
-        return [(i[0, k, 0], j[0, k, 0], i[0, k, 1], j[0, k, 1]) for k in range(nb_rec)]
+        self.cache_rec_coo[signature] = [
+            [
+                (
+                    i[n, k, 0].item(),
+                    j[n, k, 0].item(),
+                    i[n, k, 1].item(),
+                    j[n, k, 1].item(),
+                )
+                for k in range(nb_rec)
+            ]
+            for n in range(i.size(0))
+        ]
+
+        return self.cache_rec_coo[signature].pop()
 
     ######################################################################
 
@@ -242,7 +336,7 @@ class Grids(problem.Problem):
         nb_rec = 3
         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
-            r = self.rec_coo(nb_rec)
+            r = self.rec_coo(nb_rec, prevent_overlap=True)
             for n in range(nb_rec):
                 i1, j1, i2, j2 = r[n]
                 X[i1:i2, j1:j2] = c[n]
@@ -255,7 +349,7 @@ class Grids(problem.Problem):
         c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
-                r = self.rec_coo(nb_rec)
+                r = self.rec_coo(nb_rec, prevent_overlap=True)
                 i1, j1, i2, j2 = r[nb_rec - 1]
                 if (
                     i1 + di >= 0
@@ -281,7 +375,7 @@ class Grids(problem.Problem):
         direction = torch.randint(2, (1,))
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
-                r = self.rec_coo(nb_rec)
+                r = self.rec_coo(nb_rec, prevent_overlap=True)
                 i1, j1, i2, j2 = r[nb_rec - 1]
                 if i1 + 3 < i2 and j1 + 3 < j2:
                     break
@@ -306,7 +400,7 @@ class Grids(problem.Problem):
         c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1
         direction = torch.randint(4, (1,))
         for X, f_X in [(A, f_A), (B, f_B)]:
-            r = self.rec_coo(nb_rec)
+            r = self.rec_coo(nb_rec, prevent_overlap=True)
             for n in range(nb_rec):
                 i1, j1, i2, j2 = r[n]
                 X[i1:i2, j1:j2] = c[2 * n]
@@ -346,20 +440,24 @@ class Grids(problem.Problem):
         nb_rec = 3
         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
-            r = self.rec_coo(nb_rec)
+            r = self.rec_coo(nb_rec, prevent_overlap=True)
             for n in range(nb_rec):
                 i1, j1, i2, j2 = r[n]
                 X[i1:i2, j1:j2] = c[n]
-                f_X[i1:i2, j1:j2] = c[n]
                 if n == nb_rec - 1:
-                    f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0
+                    f_X[i1:i2, j1] = c[n]
+                    f_X[i1:i2, j2 - 1] = c[n]
+                    f_X[i1, j1:j2] = c[n]
+                    f_X[i2 - 1, j1:j2] = c[n]
+                else:
+                    f_X[i1:i2, j1:j2] = c[n]
 
     # @torch.compile
     def task_detect(self, A, f_A, B, f_B):
         nb_rec = 3
         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
-            r = self.rec_coo(nb_rec)
+            r = self.rec_coo(nb_rec, prevent_overlap=True)
             for n in range(nb_rec):
                 i1, j1, i2, j2 = r[n]
                 X[i1:i2, j1:j2] = c[n]
@@ -773,19 +871,18 @@ if __name__ == "__main__":
     # exit(0)
 
     # if True:
-    # nb = 72
-
-    # for t in grids.all_tasks():
-    # for t in [grids.task_count]:
-    # print(t.__name__)
-    # prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
-    # grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4)
+    nb = 72
 
-    # exit(0)
+    for t in grids.all_tasks():
+        # for t in [grids.task_replace_color]:
+        print(t.__name__)
+        prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
+        grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4)
 
     nb = 1000
 
     for t in grids.all_tasks():
+        # for t in [ grids.task_replace_color ]: #grids.all_tasks():
         start_time = time.perf_counter()
         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
         delay = time.perf_counter() - start_time
diff --git a/main.py b/main.py
index ba5f04b..3004f9c 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -437,7 +437,7 @@ def create_c_quizzes(
             if c_quizzes.size(0) > 0:
                 logproba = c_quizzes.new(c_quizzes.size(0), len(models))
                 for q, l in zip(
-                    c_quizzes.split(args.batch_size), logits.split(args.batch_size)
+                    c_quizzes.split(args.batch_size), logproba.split(args.batch_size)
                 ):
                     for model in models:
                         l[model.id] = F.cross_entropy(model(q))
index f0fb408..321df35 100755 (executable)
@@ -260,7 +260,7 @@ class QuizMachine:
         quizzes,
         mistakes=None,
     ):
-        quizzes = quizzes.clone()
+        quizzes = quizzes.clone().to("cpu")
         n_forward = quizzes[quizzes[:, 0] == self.token_forward]
         n_backward = quizzes[:, 0] == self.token_backward
         backward = quizzes[n_backward]
@@ -271,8 +271,8 @@ class QuizMachine:
         predicted_answers = 1 - predicted_prompts
         if mistakes is not None:
             # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
-            predicted_prompts *= mistakes
-            predicted_answers *= mistakes
+            predicted_prompts *= mistakes.to("cpu")
+            predicted_answers *= mistakes.to("cpu")
         else:
             # 0/2 ~ not-to-predict / to predict
             predicted_prompts *= 2