Update.
author François Fleuret <francois@fleuret.org>
Wed, 3 Jul 2024 20:08:28 +0000 (23:08 +0300)
committer François Fleuret <francois@fleuret.org>
Wed, 3 Jul 2024 20:08:28 +0000 (23:08 +0300)
reasoning.py

index b8d39ee..c442947 100755 (executable)
@@ -169,9 +169,13 @@ class Reasoning(problem.Problem):
     def nb_token_values(self):
         return len(self.colors)
 
+    # This is quite a tensorial spaghetti mess, but it samples
+    # non-overlapping rectangles quickly: generating 100k samples
+    # went from 1h50 with a lame pure-Python implementation down to
+    # 4min with this one.
     def rec_coo(self, x, n, min_height=3, min_width=3):
         K = 3
-        N = 4000
+        N = 1000
 
         while True:
             v = (
@@ -365,12 +369,8 @@ class Reasoning(problem.Problem):
             self.task_frame,
             self.task_detect,
         ]
-        prompts = torch.zeros(
-            nb, self.height, self.width * 3, dtype=torch.int64, device=self.device
-        )
-        answers = torch.zeros(
-            nb, self.height, self.width, dtype=torch.int64, device=self.device
-        )
+        prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
+        answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
         w = self.width
 
         for prompt, answer in tqdm.tqdm(
@@ -385,6 +385,7 @@ class Reasoning(problem.Problem):
             f_B = answer
             task = tasks[torch.randint(len(tasks), (1,))]
             task(A, f_A, B, f_B)
+
         return prompts.flatten(1), answers.flatten(1)
 
     def save_quizzes(