Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 6 Jul 2024 19:36:50 +0000 (22:36 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 6 Jul 2024 19:36:50 +0000 (22:36 +0300)
grids.py
quiz_machine.py

index 659bd6c..ed72099 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -647,7 +647,7 @@ class Grids(problem.Problem):
         S = self.height * self.width
         Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S]
         f_Bs = answers
-        return (B_s == f_Bs).long().min(dim=-1).values > 0
+        return (Bs == f_Bs).long().min(dim=-1).values > 0
 
     def generate_prompts_and_answers(self, nb, tasks=None, device="cpu"):
         if tasks is None:
index 26a0d8b..9f4fe96 100755 (executable)
@@ -123,9 +123,11 @@ class QuizMachine:
         n_backward = quizzes[:, 0] == self.token_backward
         backward = quizzes[n_backward]
         quizzes[n_backward] = self.reverse_time(quizzes[n_backward])
-        return not self.problem.trivial_prompts_and_answers(
-            quizzes[:, 1 : 1 + self.prompt_len],
-            quizzes[:, 2 + self.prompt_len :],
+        return torch.logical_not(
+            self.problem.trivial_prompts_and_answers(
+                quizzes[:, 1 : 1 + self.prompt_len],
+                quizzes[:, 2 + self.prompt_len :],
+            )
         )
 
     def reverse_time(self, quizzes):