X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=quiz_machine.py;h=f0fb4082beb73913460bd4f72d4448e22ed15a1c;hb=00f7b3d445af8bb57376faabbf74eadc145faf1f;hp=45b2247f50519d382bc5e04bbab5612fb60fe698;hpb=719785dbea77989a54bf7592bb6919f2e8f3f6c5;p=culture.git diff --git a/quiz_machine.py b/quiz_machine.py index 45b2247..f0fb408 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -27,8 +27,8 @@ def one_batch_masked_inplace_autoregression( input, ar_mask, seq_logproba, - temperature=1.0, - deterministic_synthesis=False, + temperature, + deterministic_synthesis, ): to_generate = (ar_mask.sum(0) > 0).nonzero() @@ -117,6 +117,19 @@ class QuizMachine: ).all() return i_forward, i_backward + def non_trivial(self, quizzes): + quizzes = quizzes.clone() + n_forward = quizzes[quizzes[:, 0] == self.token_forward] + n_backward = quizzes[:, 0] == self.token_backward + backward = quizzes[n_backward] + quizzes[n_backward] = self.reverse_time(quizzes[n_backward]) + return torch.logical_not( + self.problem.trivial_prompts_and_answers( + quizzes[:, 1 : 1 + self.prompt_len], + quizzes[:, 2 + self.prompt_len :], + ) + ) + def reverse_time(self, quizzes): i_forward, i_backward = self.indices_forward_and_backward(quizzes)