From 3ea12df1dcfc4006eb895fd62bb622e9aef6178c Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 6 Jul 2024 09:51:57 +0300 Subject: [PATCH] Update. --- main.py | 6 +++--- reasoning.py | 19 ++++++++++--------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/main.py b/main.py index 02e1a8d..ff573c4 100755 --- a/main.py +++ b/main.py @@ -13,7 +13,7 @@ from torch.nn import functional as F import ffutils import mygpt -import sky, reasoning, quiz_machine +import sky, grids, quiz_machine # world quizzes vs. culture quizzes @@ -251,8 +251,8 @@ if args.problem == "sky": speed=args.sky_speed, ) back_accuracy = False -elif args.problem == "reasoning": - problem = reasoning.Reasoning(device=device) +elif args.problem == "grids": + problem = grids.Grids(device=device) back_accuracy = True else: raise ValueError diff --git a/reasoning.py b/reasoning.py index 058c410..9462f87 100755 --- a/reasoning.py +++ b/reasoning.py @@ -17,7 +17,7 @@ from torch.nn import functional as F import problem -class Reasoning(problem.Problem): +class Grids(problem.Problem): named_colors = [ ("white", [255, 255, 255]), ("red", [255, 0, 0]), @@ -421,7 +421,7 @@ class Reasoning(problem.Problem): if n < nb_rec - 1: f_X[i1, j1] = c[-1] - def contact(X, i, j, q): + def contact(self, X, i, j, q): nq, nq_diag = 0, 0 no = 0 @@ -466,7 +466,7 @@ class Reasoning(problem.Problem): k = torch.randperm(self.height * self.width) for p in range(self.height * self.width): i, j = k[p] % self.height, k[p] // self.height - no, nq, nq_diag = contact(X, i, j, c[q[p]]) + no, nq, nq_diag = self.contact(X, i, j, c[q[p]]) if no == 0 and nq_diag == 0: if nq == 0: if nb[q[p]] < self.width: @@ -693,19 +693,20 @@ if __name__ == "__main__": nb = 48 - reasoning = Reasoning() + grids = Grids() - for t in [reasoning.task_islands]: # reasoning.all_tasks(): + for t in grids.all_tasks(): + # for t in [grids.task_islands]: print(t.__name__) - prompts, answers = reasoning.generate_prompts_and_answers(nb, tasks=[t]) - reasoning.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4) + prompts, answers = grids.generate_prompts_and_answers(nb, tasks=[t]) + grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4) exit(0) nb = 72 start_time = time.perf_counter() - prompts, answers = reasoning.generate_prompts_and_answers(nb) + prompts, answers = grids.generate_prompts_and_answers(nb) delay = time.perf_counter() - start_time print(f"{prompts.size(0)/delay:02f} seq/s") @@ -713,7 +714,7 @@ if __name__ == "__main__": predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1) predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1) - reasoning.save_quizzes( + grids.save_quizzes( "/tmp", "test", prompts[:nb], -- 2.39.5