Update.
[culture.git] / quiz_machine.py
index f0fb408..c1477c9 100755 (executable)
@@ -260,7 +260,7 @@ class QuizMachine:
         quizzes,
         mistakes=None,
     ):
-        quizzes = quizzes.clone()
+        quizzes = quizzes.clone().to("cpu")
         n_forward = quizzes[quizzes[:, 0] == self.token_forward]
         n_backward = quizzes[:, 0] == self.token_backward
         backward = quizzes[n_backward]
@@ -271,8 +271,8 @@ class QuizMachine:
         predicted_answers = 1 - predicted_prompts
         if mistakes is not None:
             # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
-            predicted_prompts *= mistakes
-            predicted_answers *= mistakes
+            predicted_prompts *= mistakes.to("cpu")
+            predicted_answers *= mistakes.to("cpu")
         else:
             # 0/2 ~ not-to-predict / to predict
             predicted_prompts *= 2
@@ -416,6 +416,25 @@ class QuizMachine:
         else:
             self.test_c_quizzes.append(new_c_quizzes)
 
+    def logproba_solution(self, models, c_quizzes):
+        logproba = c_quizzes.new_zeros(c_quizzes.size(0), len(models))
+
+        for model in models:
+            for input, l in zip(
+                c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
+            ):
+                ar_mask = self.make_ar_mask(input)
+                output = model(mygpt.BracketedSequence(input)).x
+                ce = (
+                    F.cross_entropy(output.transpose(1, 2), input, reduction="none")
+                    * ar_mask
+                )
+                l[:, model.id] = ce.sum(dim=-1)
+
+        return logproba
+
+    ###############################################################
+
     def compute_correctness(
         self,
         c_quizzes,