X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=quiz_machine.py;h=f0fb4082beb73913460bd4f72d4448e22ed15a1c;hb=93cea45f62046a3481d6c05ab2cfe70f6dbc93b3;hp=de1e8d1c0bee7532f64a65d78c9af5a49a6b1821;hpb=e08d6e734387e6a7016521d3379c945f5cf2ad87;p=culture.git diff --git a/quiz_machine.py b/quiz_machine.py index de1e8d1..f0fb408 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -27,8 +27,8 @@ def one_batch_masked_inplace_autoregression( input, ar_mask, seq_logproba, - temperature=1.0, - deterministic_synthesis=False, + temperature, + deterministic_synthesis, ): to_generate = (ar_mask.sum(0) > 0).nonzero() @@ -50,7 +50,8 @@ def one_batch_masked_inplace_autoregression( t_next = dist.sample() all_n = torch.arange(t_next.size(0)) - seq_logproba += logits[all_n, t_next].sum(dim=-1) + + seq_logproba += logits[all_n, t_next] input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s] @@ -116,6 +117,19 @@ class QuizMachine: ).all() return i_forward, i_backward + def non_trivial(self, quizzes): + quizzes = quizzes.clone() + n_forward = quizzes[quizzes[:, 0] == self.token_forward] + n_backward = quizzes[:, 0] == self.token_backward + backward = quizzes[n_backward] + quizzes[n_backward] = self.reverse_time(quizzes[n_backward]) + return torch.logical_not( + self.problem.trivial_prompts_and_answers( + quizzes[:, 1 : 1 + self.prompt_len], + quizzes[:, 2 + self.prompt_len :], + ) + ) + def reverse_time(self, quizzes): i_forward, i_backward = self.indices_forward_and_backward(quizzes) @@ -359,11 +373,11 @@ class QuizMachine: backward_nb_total = correct[n_backward].size(0) self.logger( - f"{log_prefix}_forward_accuracy {n_epoch} model {model.id} nb_correct {forward_nb_correct} / {forward_nb_total}" + f"{log_prefix}_forward_accuracy {n_epoch} model {model.id} nb_correct {forward_nb_correct} / {forward_nb_total} ({forward_nb_correct*100/forward_nb_total} %)" ) self.logger( - f"{log_prefix}_backward_accuracy {n_epoch} model {model.id} nb_correct {backward_nb_correct} / {backward_nb_total}" + f"{log_prefix}_backward_accuracy {n_epoch} model {model.id} nb_correct {backward_nb_correct} / {backward_nb_total} ({backward_nb_correct*100/backward_nb_total} %)" ) return result, correct @@ -420,11 +434,11 @@ class QuizMachine: nb_correct = 0 + seq_logproba[...] = 0.0 + for model in models_for_validation: result = c_quizzes.clone() - seq_logproba[...] = 0.0 - ar_mask = self.make_ar_mask(result) masked_inplace_autoregression(