Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 12 Jul 2024 08:08:05 +0000 (10:08 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 12 Jul 2024 08:08:05 +0000 (10:08 +0200)
grids.py
main.py
quiz_machine.py

index cfc7d16..5dad6f3 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -37,11 +37,34 @@ class Grids(problem.Problem):
         max_nb_cached_chunks=None,
         chunk_size=None,
         nb_threads=-1,
+        tasks=None,
     ):
         self.colors = torch.tensor([c for _, c in self.named_colors])
         self.height = 10
         self.width = 10
         self.cache_rec_coo = {}
+
+        all_tasks = [
+            self.task_replace_color,
+            self.task_translate,
+            self.task_grow,
+            self.task_half_fill,
+            self.task_frame,
+            self.task_detect,
+            self.task_count,
+            self.task_trajectory,
+            self.task_bounce,
+            self.task_scale,
+            self.task_symbols,
+            self.task_isometry,
+            #            self.task_path,
+        ]
+
+        if tasks is None:
+            self.all_tasks = all_tasks
+        else:
+            self.all_tasks = [getattr(self, "task_" + t) for t in tasks.split(",")]
+
         super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
 
     ######################################################################
@@ -398,7 +421,7 @@ class Grids(problem.Problem):
                     f_X[i1:i2, j1:j2] = c[n]
 
     # @torch.compile
-    def task_color_grow(self, A, f_A, B, f_B):
+    def task_half_fill(self, A, f_A, B, f_B):
         di, dj = torch.randint(2, (2,)) * 2 - 1
         nb_rec = 3
         c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1
@@ -715,7 +738,7 @@ class Grids(problem.Problem):
             f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
 
     # @torch.compile
-    def task_ortho(self, A, f_A, B, f_B):
+    def task_isometry(self, A, f_A, B, f_B):
         nb_rec = 3
         di, dj = torch.randint(3, (2,)) - 1
         o = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])
@@ -939,23 +962,6 @@ class Grids(problem.Problem):
 
     ######################################################################
 
-    def all_tasks(self):
-        return [
-            self.task_replace_color,
-            self.task_translate,
-            self.task_grow,
-            self.task_color_grow,
-            self.task_frame,
-            self.task_detect,
-            self.task_count,
-            self.task_trajectory,
-            self.task_bounce,
-            self.task_scale,
-            self.task_symbols,
-            self.task_ortho,
-            #            self.task_path,
-        ]
-
     def trivial_prompts_and_answers(self, prompts, answers):
         S = self.height * self.width
         Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S]
@@ -964,7 +970,7 @@ class Grids(problem.Problem):
 
     def generate_prompts_and_answers_(self, nb, tasks=None, progress_bar=False):
         if tasks is None:
-            tasks = self.all_tasks()
+            tasks = self.all_tasks
 
         S = self.height * self.width
         prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64)
@@ -1012,7 +1018,7 @@ class Grids(problem.Problem):
 
     def save_some_examples(self, result_dir):
         nb, nrow = 72, 4
-        for t in self.all_tasks():
+        for t in self.all_tasks:
             print(t.__name__)
             prompts, answers = self.generate_prompts_and_answers_(nb, tasks=[t])
             self.save_quizzes(
@@ -1043,7 +1049,7 @@ if __name__ == "__main__":
     nb, nrow = 72, 4
     # nb, nrow = 8, 2
 
-    # for t in grids.all_tasks():
+    # for t in grids.all_tasks:
     for t in [
         grids.task_replace_color,
         grids.task_frame,
@@ -1056,8 +1062,8 @@ if __name__ == "__main__":
 
     nb = 1000
 
-    for t in grids.all_tasks():
-        # for t in [ grids.task_replace_color ]: #grids.all_tasks():
+    for t in grids.all_tasks:
+        # for t in [ grids.task_replace_color ]: #grids.all_tasks:
         start_time = time.perf_counter()
         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
         delay = time.perf_counter() - start_time
diff --git a/main.py b/main.py
index b88cbc4..fc55b9c 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -98,6 +98,19 @@ parser.add_argument("--dirty_debug", action="store_true", default=False)
 
 ######################################################################
 
+grids_tasks = ", ".join(
+    [x.__name__.removeprefix("task_") for x in grids.Grids().all_tasks]
+)
+
+parser.add_argument(
+    "--grids_tasks",
+    type=str,
+    default=None,
+    help="A comma-separated subset of: " + grids_tasks + ", or None for all.",
+)
+
+######################################################################
+
 parser.add_argument("--sky_height", type=int, default=6)
 
 parser.add_argument("--sky_width", type=int, default=8)
@@ -250,6 +263,7 @@ elif args.problem == "grids":
         max_nb_cached_chunks=args.nb_gpus * args.nb_train_samples // 100,
         chunk_size=100,
         nb_threads=args.nb_threads,
+        tasks=args.grids_tasks,
     )
     back_accuracy = True
 else:
index 4f704a0..631d41b 100755 (executable)
@@ -416,7 +416,7 @@ class QuizMachine:
 
     def logproba_of_solutions(self, models, c_quizzes):
         logproba = c_quizzes.new_zeros(
-            c_quizzes.size(0), len(models), device=self.device
+            c_quizzes.size(0), len(models), device=self.device, dtype=torch.float32
         )
 
         for model in models: