self.task_bounce,
self.task_scale,
self.task_symbols,
- self.task_islands,
+ # self.task_islands,
]
+ def trivial_prompts_and_answers(self, prompts, answers):
+ S = self.height * self.width
+ Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S]
+ f_Bs = answers
+ return (B_s == f_Bs).long().min(dim=-1).values > 0
+
def generate_prompts_and_answers(self, nb, tasks=None, device="cpu"):
if tasks is None:
tasks = self.all_tasks()
temperature=args.generation_temperature,
)
+ c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
+
nb_correct, seq_logproba = quiz_machine.compute_correctness(
c_quizzes,
models,
).all()
return i_forward, i_backward
+ def non_trivial(self, quizzes):
+ quizzes = quizzes.clone()
+ n_forward = quizzes[quizzes[:, 0] == self.token_forward]
+ n_backward = quizzes[:, 0] == self.token_backward
+ backward = quizzes[n_backward]
+ quizzes[n_backward] = self.reverse_time(quizzes[n_backward])
+ return not self.problem.trivial_prompts_and_answers(
+ quizzes[:, 1 : 1 + self.prompt_len],
+ quizzes[:, 2 + self.prompt_len :],
+ )
+
def reverse_time(self, quizzes):
i_forward, i_backward = self.indices_forward_and_backward(quizzes)