Update.
[culture.git] / quiz_machine.py
index de1e8d1..cb187be 100755 (executable)
@@ -50,7 +50,8 @@ def one_batch_masked_inplace_autoregression(
             t_next = dist.sample()
 
         all_n = torch.arange(t_next.size(0))
-        seq_logproba += logits[all_n, t_next].sum(dim=-1)
+
+        seq_logproba += logits[all_n, t_next]
 
         input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
 
@@ -420,11 +421,11 @@ class QuizMachine:
 
         nb_correct = 0
 
+        seq_logproba[...] = 0.0
+
         for model in models_for_validation:
             result = c_quizzes.clone()
 
-            seq_logproba[...] = 0.0
-
             ar_mask = self.make_ar_mask(result)
 
             masked_inplace_autoregression(