Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 21 Jul 2024 07:13:48 +0000 (09:13 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 21 Jul 2024 07:13:48 +0000 (09:13 +0200)
grids.py

index 0b789ef..d3e7dcc 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -222,7 +222,7 @@ class Grids(problem.Problem):
             self.task_symbols,
             self.task_isometry,
             self.task_corners,
-            self.task_proximity,
+            self.task_contact,
             # self.task_islands, # TOO MESSY
         ]
 
@@ -1277,7 +1277,7 @@ class Grids(problem.Problem):
 
     # @torch.compile
     # [ai1,ai2] [bi1,bi2]
-    def task_proximity(self, A, f_A, B, f_B):
+    def task_contact(self, A, f_A, B, f_B):
         def rec_dist(a, b):
             ai1, aj1, ai2, aj2 = a
             bi1, bj1, bi2, bj2 = b
@@ -1297,10 +1297,12 @@ class Grids(problem.Problem):
             for n in range(nb_rec):
                 i1, j1, i2, j2 = r[n]
                 X[i1:i2, j1:j2] = c[n]
+                f_X[i1:i2, j1:j2] = c[n]
                 if d[n] == 0:
-                    f_X[i1:i2, j1:j2] = c[0]
-                else:
-                    f_X[i1:i2, j1:j2] = c[n]
+                    f_X[i1, j1:j2] = c[0]
+                    f_X[i2 - 1, j1:j2] = c[0]
+                    f_X[i1:i2, j1] = c[0]
+                    f_X[i1:i2, j2 - 1] = c[0]
 
     # @torch.compile
     # [ai1,ai2] [bi1,bi2]
@@ -1421,8 +1423,8 @@ if __name__ == "__main__":
     # nb, nrow = 8, 2
 
     # for t in grids.all_tasks:
-    # for t in [grids.task_proximity, grids.task_corners]:
-    for t in [grids.task_symbols]:
+    for t in [grids.task_contact]:
+        # for t in [grids.task_symbols]:
         print(t.__name__)
         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
         # prompts[...] = torch.randint(grids.nb_token_values(), prompts.size())