From: François Fleuret Date: Sat, 6 Jul 2024 19:36:50 +0000 (+0300) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;ds=sidebyside;h=14c70d068471a163ecd389e0c1667e561ea056f9;p=culture.git Update. --- diff --git a/grids.py b/grids.py index 659bd6c..ed72099 100755 --- a/grids.py +++ b/grids.py @@ -647,7 +647,7 @@ class Grids(problem.Problem): S = self.height * self.width Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S] f_Bs = answers - return (B_s == f_Bs).long().min(dim=-1).values > 0 + return (Bs == f_Bs).long().min(dim=-1).values > 0 def generate_prompts_and_answers(self, nb, tasks=None, device="cpu"): if tasks is None: diff --git a/quiz_machine.py b/quiz_machine.py index 26a0d8b..9f4fe96 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -123,9 +123,11 @@ class QuizMachine: n_backward = quizzes[:, 0] == self.token_backward backward = quizzes[n_backward] quizzes[n_backward] = self.reverse_time(quizzes[n_backward]) - return not self.problem.trivial_prompts_and_answers( - quizzes[:, 1 : 1 + self.prompt_len], - quizzes[:, 2 + self.prompt_len :], + return torch.logical_not( + self.problem.trivial_prompts_and_answers( + quizzes[:, 1 : 1 + self.prompt_len], + quizzes[:, 2 + self.prompt_len :], + ) ) def reverse_time(self, quizzes):