From 57a13bdaf395838f93dcd67dce3151e2ed9eb3f1 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 6 Jul 2024 22:23:06 +0300 Subject: [PATCH] Update. --- grids.py | 8 +++++++- main.py | 2 ++ quiz_machine.py | 11 +++++++++++ 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/grids.py b/grids.py index 03fa375..659bd6c 100755 --- a/grids.py +++ b/grids.py @@ -640,9 +640,15 @@ class Grids(problem.Problem): self.task_bounce, self.task_scale, self.task_symbols, - self.task_islands, + # self.task_islands, ] + def trivial_prompts_and_answers(self, prompts, answers): + S = self.height * self.width + Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S] + f_Bs = answers + return (B_s == f_Bs).long().min(dim=-1).values > 0 + def generate_prompts_and_answers(self, nb, tasks=None, device="cpu"): if tasks is None: tasks = self.all_tasks() diff --git a/main.py b/main.py index a2a771f..00a9492 100755 --- a/main.py +++ b/main.py @@ -421,6 +421,8 @@ def create_c_quizzes( temperature=args.generation_temperature, ) + c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)] + nb_correct, seq_logproba = quiz_machine.compute_correctness( c_quizzes, models, diff --git a/quiz_machine.py b/quiz_machine.py index 45b2247..26a0d8b 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -117,6 +117,17 @@ class QuizMachine: ).all() return i_forward, i_backward + def non_trivial(self, quizzes): + quizzes = quizzes.clone() + n_forward = quizzes[quizzes[:, 0] == self.token_forward] + n_backward = quizzes[:, 0] == self.token_backward + backward = quizzes[n_backward] + quizzes[n_backward] = self.reverse_time(quizzes[n_backward]) + return not self.problem.trivial_prompts_and_answers( + quizzes[:, 1 : 1 + self.prompt_len], + quizzes[:, 2 + self.prompt_len :], + ) + def reverse_time(self, quizzes): i_forward, i_backward = self.indices_forward_and_backward(quizzes) -- 2.20.1