X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=quiz_machine.py;h=cb187beb03cd041500c47f2666bffc8ce9a2e2fb;hb=e0ab20005e2578edff27d4246c6904cf1047ed22;hp=de1e8d1c0bee7532f64a65d78c9af5a49a6b1821;hpb=e08d6e734387e6a7016521d3379c945f5cf2ad87;p=culture.git diff --git a/quiz_machine.py b/quiz_machine.py index de1e8d1..cb187be 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -50,7 +50,8 @@ def one_batch_masked_inplace_autoregression( t_next = dist.sample() all_n = torch.arange(t_next.size(0)) - seq_logproba += logits[all_n, t_next].sum(dim=-1) + + seq_logproba += logits[all_n, t_next] input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s] @@ -420,11 +421,11 @@ class QuizMachine: nb_correct = 0 + seq_logproba[...] = 0.0 + for model in models_for_validation: result = c_quizzes.clone() - seq_logproba[...] = 0.0 - ar_mask = self.make_ar_mask(result) masked_inplace_autoregression(