Update.
[culture.git] / quiz_machine.py
index 321df35..c1477c9 100755 (executable)
@@ -416,6 +416,25 @@ class QuizMachine:
         else:
             self.test_c_quizzes.append(new_c_quizzes)
 
+    def logproba_solution(self, models, c_quizzes):
+        logproba = c_quizzes.new_zeros(c_quizzes.size(0), len(models))
+
+        for model in models:
+            for input, l in zip(
+                c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
+            ):
+                ar_mask = self.make_ar_mask(input)
+                output = model(mygpt.BracketedSequence(input)).x
+                ce = (
+                    F.cross_entropy(output.transpose(1, 2), input, reduction="none")
+                    * ar_mask
+                )
+                l[:, model.id] = ce.sum(dim=-1)
+
+        return logproba
+
+    ###############################################################
+
     def compute_correctness(
         self,
         c_quizzes,