X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=quiz_machine.py;h=c1477c9bb497498ddbe5aa4b2cc898ba1a915796;hb=f7a8cd141a39039048d3e9311220a33079f2cfc7;hp=26a0d8b2c40afe8372fb783f45da63fea69f210e;hpb=57a13bdaf395838f93dcd67dce3151e2ed9eb3f1;p=culture.git diff --git a/quiz_machine.py b/quiz_machine.py index 26a0d8b..c1477c9 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -27,8 +27,8 @@ def one_batch_masked_inplace_autoregression( input, ar_mask, seq_logproba, - temperature=1.0, - deterministic_synthesis=False, + temperature, + deterministic_synthesis, ): to_generate = (ar_mask.sum(0) > 0).nonzero() @@ -123,9 +123,11 @@ class QuizMachine: n_backward = quizzes[:, 0] == self.token_backward backward = quizzes[n_backward] quizzes[n_backward] = self.reverse_time(quizzes[n_backward]) - return not self.problem.trivial_prompts_and_answers( - quizzes[:, 1 : 1 + self.prompt_len], - quizzes[:, 2 + self.prompt_len :], + return torch.logical_not( + self.problem.trivial_prompts_and_answers( + quizzes[:, 1 : 1 + self.prompt_len], + quizzes[:, 2 + self.prompt_len :], + ) ) def reverse_time(self, quizzes): @@ -258,7 +260,7 @@ class QuizMachine: quizzes, mistakes=None, ): - quizzes = quizzes.clone() + quizzes = quizzes.clone().to("cpu") n_forward = quizzes[quizzes[:, 0] == self.token_forward] n_backward = quizzes[:, 0] == self.token_backward backward = quizzes[n_backward] @@ -269,8 +271,8 @@ class QuizMachine: predicted_answers = 1 - predicted_prompts if mistakes is not None: # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct - predicted_prompts *= mistakes - predicted_answers *= mistakes + predicted_prompts *= mistakes.to("cpu") + predicted_answers *= mistakes.to("cpu") else: # 0/2 ~ not-to-predict / to predict predicted_prompts *= 2 @@ -414,6 +416,25 @@ class QuizMachine: else: self.test_c_quizzes.append(new_c_quizzes) + def logproba_solution(self, models, c_quizzes): + logproba = c_quizzes.new_zeros(c_quizzes.size(0), len(models)) + + for model in models: + for input, l in zip( + c_quizzes.split(self.batch_size), logproba.split(self.batch_size) + ): + ar_mask = self.make_ar_mask(input) + output = model(mygpt.BracketedSequence(input)).x + ce = ( + F.cross_entropy(output.transpose(1, 2), input, reduction="none") + * ar_mask + ) + l[:, model.id] = ce.sum(dim=-1) + + return logproba + + ############################################################### + def compute_correctness( self, c_quizzes,