X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=quiz_machine.py;h=f0fb4082beb73913460bd4f72d4448e22ed15a1c;hb=00f7b3d445af8bb57376faabbf74eadc145faf1f;hp=26a0d8b2c40afe8372fb783f45da63fea69f210e;hpb=57a13bdaf395838f93dcd67dce3151e2ed9eb3f1;p=culture.git diff --git a/quiz_machine.py b/quiz_machine.py index 26a0d8b..f0fb408 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -27,8 +27,8 @@ def one_batch_masked_inplace_autoregression( input, ar_mask, seq_logproba, - temperature=1.0, - deterministic_synthesis=False, + temperature, + deterministic_synthesis, ): to_generate = (ar_mask.sum(0) > 0).nonzero() @@ -123,9 +123,11 @@ class QuizMachine: n_backward = quizzes[:, 0] == self.token_backward backward = quizzes[n_backward] quizzes[n_backward] = self.reverse_time(quizzes[n_backward]) - return not self.problem.trivial_prompts_and_answers( - quizzes[:, 1 : 1 + self.prompt_len], - quizzes[:, 2 + self.prompt_len :], + return torch.logical_not( + self.problem.trivial_prompts_and_answers( + quizzes[:, 1 : 1 + self.prompt_len], + quizzes[:, 2 + self.prompt_len :], + ) ) def reverse_time(self, quizzes):