From 167c56ace610c3b975c702203bb7c7ddf74930ae Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 3 Jul 2024 23:08:28 +0300 Subject: [PATCH] Update. --- reasoning.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/reasoning.py b/reasoning.py index b8d39ee..c442947 100755 --- a/reasoning.py +++ b/reasoning.py @@ -169,9 +169,13 @@ class Reasoning(problem.Problem): def nb_token_values(self): return len(self.colors) + # That's quite a tensorial spaghetti mess to sample + # non-overlapping rectangles quickly, but made the generation of + # 100k samples from 1h50 with a lame pure python code to 4min with + # this one. def rec_coo(self, x, n, min_height=3, min_width=3): K = 3 - N = 4000 + N = 1000 while True: v = ( @@ -365,12 +369,8 @@ class Reasoning(problem.Problem): self.task_frame, self.task_detect, ] - prompts = torch.zeros( - nb, self.height, self.width * 3, dtype=torch.int64, device=self.device - ) - answers = torch.zeros( - nb, self.height, self.width, dtype=torch.int64, device=self.device - ) + prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64) + answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64) w = self.width for prompt, answer in tqdm.tqdm( @@ -385,6 +385,7 @@ class Reasoning(problem.Problem): f_B = answer task = tasks[torch.randint(len(tasks), (1,))] task(A, f_A, B, f_B) + return prompts.flatten(1), answers.flatten(1) def save_quizzes( -- 2.39.5