From 14c70d068471a163ecd389e0c1667e561ea056f9 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 6 Jul 2024 22:36:50 +0300 Subject: [PATCH] Update. --- grids.py | 2 +- quiz_machine.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/grids.py b/grids.py index 659bd6c..ed72099 100755 --- a/grids.py +++ b/grids.py @@ -647,7 +647,7 @@ class Grids(problem.Problem): S = self.height * self.width Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S] f_Bs = answers - return (B_s == f_Bs).long().min(dim=-1).values > 0 + return (Bs == f_Bs).long().min(dim=-1).values > 0 def generate_prompts_and_answers(self, nb, tasks=None, device="cpu"): if tasks is None: diff --git a/quiz_machine.py b/quiz_machine.py index 26a0d8b..9f4fe96 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -123,9 +123,11 @@ class QuizMachine: n_backward = quizzes[:, 0] == self.token_backward backward = quizzes[n_backward] quizzes[n_backward] = self.reverse_time(quizzes[n_backward]) - return not self.problem.trivial_prompts_and_answers( - quizzes[:, 1 : 1 + self.prompt_len], - quizzes[:, 2 + self.prompt_len :], + return torch.logical_not( + self.problem.trivial_prompts_and_answers( + quizzes[:, 1 : 1 + self.prompt_len], + quizzes[:, 2 + self.prompt_len :], + ) ) def reverse_time(self, quizzes): -- 2.39.5