Update.
author François Fleuret <francois@fleuret.org>
Sun, 21 Jul 2024 06:05:53 +0000 (08:05 +0200)
committer François Fleuret <francois@fleuret.org>
Sun, 21 Jul 2024 06:05:53 +0000 (08:05 +0200)
grids.py
main.py
quiz_machine.py

diff --git a/grids.py b/grids.py
index bbb18d2..3c00abe 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -219,9 +219,9 @@ class Grids(problem.Problem):
             self.task_trajectory,
             self.task_bounce,
             self.task_scale,
-            self.task_symbols,
+            self.task_symbols,
             self.task_isometry,
-            self.task_islands,
+            self.task_islands,
         ]
 
         if tasks is None:
@@ -430,7 +430,8 @@ class Grids(problem.Problem):
             while True:
                 i = torch.randint(self.height, (N * nb_rec, 2)).sort(dim=-1).values
                 j = torch.randint(self.width, (N * nb_rec, 2)).sort(dim=-1).values
-
+                i[:, 1] += 1
+                j[:, 1] += 1
                 big_enough = (
                     (i[:, 1] >= i[:, 0] + min_height)
                     & (j[:, 1] >= j[:, 0] + min_height)
@@ -903,27 +904,28 @@ class Grids(problem.Problem):
                 if d.min() > delta:
                     break
 
-            for k in range(1, nb_rec):
-                X[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k]
-
             ai, aj = i.float().mean(), j.float().mean()
 
             q = torch.randint(3, (1,)).item() + 1
 
-            X[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0]
-            X[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0]
-            X[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0]
-            X[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0]
-
             assert i[q] != ai and j[q] != aj
 
+            for Z in [X, f_X]:
+                for k in range(1, nb_rec):
+                    Z[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k]
+                Z[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0]
+                Z[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0]
+                Z[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0]
+                Z[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0]
+
+            # f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
+            f_X[i[0] + delta // 2, j[0] + delta // 2] = c[q]
+
             X[
                 i[0] + delta // 2 + (i[q] - ai).sign().long(),
                 j[0] + delta // 2 + (j[q] - aj).sign().long(),
             ] = c[nb_rec]
 
-            f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
-
     # @torch.compile
     def task_isometry(self, A, f_A, B, f_B):
         nb_rec = 3
@@ -1271,6 +1273,52 @@ class Grids(problem.Problem):
             f_X[i, s : s + w1] = c1
             f_X[i, s + w1 : s + w1 + w2] = c2
 
+    # @torch.compile
+    # [ai1,ai2] [bi1,bi2]
+    def task_proximity(self, A, f_A, B, f_B):
+        def rec_dist(a, b):
+            ai1, aj1, ai2, aj2 = a
+            bi1, bj1, bi2, bj2 = b
+            v = max(ai1 - bi2, bi1 - ai2)
+            h = max(aj1 - bj2, bj1 - aj2)
+            return min(max(v, 0) + max(h + 1, 0), max(v + 1, 0) + max(h, 0))
+
+        nb_rec = 3
+        c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            while True:
+                r = self.rec_coo(nb_rec, prevent_overlap=True)
+                d = [rec_dist(r[0], r[k]) for k in range(nb_rec)]
+                if min(d[1:]) == 0:
+                    break
+
+            for n in range(nb_rec):
+                i1, j1, i2, j2 = r[n]
+                X[i1:i2, j1:j2] = c[n]
+                if d[n] == 0:
+                    f_X[i1:i2, j1:j2] = c[0]
+                else:
+                    f_X[i1:i2, j1:j2] = c[n]
+
+    # @torch.compile
+    # [ai1,ai2] [bi1,bi2]
+    def task_corners(self, A, f_A, B, f_B):
+        polarity = torch.randint(2, (1,)).item()
+        nb_rec = 3
+        c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            r = self.rec_coo(nb_rec, prevent_overlap=True)
+
+            for n in range(nb_rec):
+                i1, j1, i2, j2 = r[n]
+                if polarity == 0:
+                    X[i1, j1] = c[n]
+                    X[i2 - 1, j2 - 1] = c[n]
+                else:
+                    X[i1, j2 - 1] = c[n]
+                    X[i2 - 1, j1] = c[n]
+                f_X[i1:i2, j1:j2] = c[n]
+
     ######################################################################
 
     def trivial_prompts_and_answers(self, prompts, answers):
@@ -1371,7 +1419,8 @@ if __name__ == "__main__":
     # nb, nrow = 8, 2
 
     # for t in grids.all_tasks:
-    for t in [grids.task_convex]:
+    # for t in [grids.task_proximity, grids.task_corners]:
+    for t in [grids.task_symbols]:
         print(t.__name__)
         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
         # prompts[...] = torch.randint(grids.nb_token_values(), prompts.size())
diff --git a/main.py b/main.py
index 5ce9731..b7d0431 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -343,6 +343,7 @@ def run_tests(model, quiz_machine, deterministic_synthesis, local_device=main_de
         model.main_test_accuracy = quiz_machine.produce_results(
             n_epoch=n_epoch,
             model=model,
+            input=full_input[:2000],
             result_dir=args.result_dir,
             deterministic_synthesis=deterministic_synthesis,
         )
diff --git a/quiz_machine.py b/quiz_machine.py
index c006ea4..cbaa7cd 100755 (executable)
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -290,7 +290,9 @@ class QuizMachine:
 
     ######################################################################
 
-    def produce_results(self, n_epoch, model, result_dir, deterministic_synthesis):
+    def produce_results(
+        self, n_epoch, model, input, result_dir, deterministic_synthesis
+    ):
         def compute_accuracy(input, log_prefix=None):
             input = input.to(self.device)
             ar_mask = self.problem.make_ar_mask(input, shape="fwd_3_bck_123")
@@ -334,9 +336,7 @@ class QuizMachine:
 
             return result, correct
 
-        test_result, test_correct = compute_accuracy(
-            model.test_w_quizzes[:2000], log_prefix="test"
-        )
+        test_result, test_correct = compute_accuracy(input, log_prefix="test")
 
         n_test_p2a = model.test_w_quizzes[:2000, 0] == self.problem.token_forward