Update.
[culture.git] / grids.py
index 7aec62c..002a33f 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -37,11 +37,34 @@ class Grids(problem.Problem):
         max_nb_cached_chunks=None,
         chunk_size=None,
         nb_threads=-1,
+        tasks=None,
     ):
         self.colors = torch.tensor([c for _, c in self.named_colors])
         self.height = 10
         self.width = 10
         self.cache_rec_coo = {}
+
+        all_tasks = [
+            self.task_replace_color,
+            self.task_translate,
+            self.task_grow,
+            self.task_half_fill,
+            self.task_frame,
+            self.task_detect,
+            self.task_count,
+            self.task_trajectory,
+            self.task_bounce,
+            self.task_scale,
+            self.task_symbols,
+            self.task_isometry,
+            #            self.task_path,
+        ]
+
+        if tasks is None:
+            self.all_tasks = all_tasks
+        else:
+            self.all_tasks = [getattr(self, "task_" + t) for t in tasks.split(",")]
+
         super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
 
     ######################################################################
@@ -398,7 +421,7 @@ class Grids(problem.Problem):
                     f_X[i1:i2, j1:j2] = c[n]
 
     # @torch.compile
-    def task_color_grow(self, A, f_A, B, f_B):
+    def task_half_fill(self, A, f_A, B, f_B):
         di, dj = torch.randint(2, (2,)) * 2 - 1
         nb_rec = 3
         c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1
@@ -715,7 +738,7 @@ class Grids(problem.Problem):
             f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
 
     # @torch.compile
-    def task_ortho(self, A, f_A, B, f_B):
+    def task_isometry(self, A, f_A, B, f_B):
         nb_rec = 3
         di, dj = torch.randint(3, (2,)) - 1
         o = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])
@@ -939,23 +962,6 @@ class Grids(problem.Problem):
 
     ######################################################################
 
-    def all_tasks(self):
-        return [
-            self.task_replace_color,
-            self.task_translate,
-            self.task_grow,
-            self.task_color_grow,
-            self.task_frame,
-            self.task_detect,
-            self.task_count,
-            self.task_trajectory,
-            self.task_bounce,
-            self.task_scale,
-            self.task_symbols,
-            self.task_ortho,
-            #            self.task_path,
-        ]
-
     def trivial_prompts_and_answers(self, prompts, answers):
         S = self.height * self.width
         Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S]
@@ -964,7 +970,7 @@ class Grids(problem.Problem):
 
     def generate_prompts_and_answers_(self, nb, tasks=None, progress_bar=False):
         if tasks is None:
-            tasks = self.all_tasks()
+            tasks = self.all_tasks
 
         S = self.height * self.width
         prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64)
@@ -990,7 +996,7 @@ class Grids(problem.Problem):
 
         return prompts.flatten(1), answers.flatten(1)
 
-    def save_quizzes(
+    def save_quiz_illustrations(
         self,
         result_dir,
         filename_prefix,
@@ -1012,10 +1018,10 @@ class Grids(problem.Problem):
 
     def save_some_examples(self, result_dir):
         nb, nrow = 72, 4
-        for t in self.all_tasks():
+        for t in self.all_tasks:
             print(t.__name__)
             prompts, answers = self.generate_prompts_and_answers_(nb, tasks=[t])
-            self.save_quizzes(
+            self.save_quiz_illustrations(
                 result_dir, t.__name__, prompts[:nb], answers[:nb], nrow=nrow
             )
 
@@ -1043,18 +1049,23 @@ if __name__ == "__main__":
     nb, nrow = 72, 4
     # nb, nrow = 8, 2
 
-    # for t in grids.all_tasks():
-    for t in [grids.task_puzzle]:
+    # for t in grids.all_tasks:
+    for t in [
+        grids.task_replace_color,
+        grids.task_frame,
+    ]:
         print(t.__name__)
         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
-        grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow)
+        grids.save_quiz_illustrations(
+            "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow
+        )
 
     exit(0)
 
     nb = 1000
 
-    for t in grids.all_tasks():
-        # for t in [ grids.task_replace_color ]: #grids.all_tasks():
+    for t in grids.all_tasks:
+        # for t in [ grids.task_replace_color ]: #grids.all_tasks:
         start_time = time.perf_counter()
         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
         delay = time.perf_counter() - start_time
@@ -1066,7 +1077,7 @@ if __name__ == "__main__":
     predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
     predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
 
-    grids.save_quizzes(
+    grids.save_quiz_illustrations(
         "/tmp",
         "test",
         prompts[:nb],