Update.
[culture.git] / quiz_machine.py
index 26a0d8b..c1477c9 100755
@@ -27,8 +27,8 @@ def one_batch_masked_inplace_autoregression(
     input,
     ar_mask,
     seq_logproba,
-    temperature=1.0,
-    deterministic_synthesis=False,
+    temperature,
+    deterministic_synthesis,
 ):
     to_generate = (ar_mask.sum(0) > 0).nonzero()
 
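The hunk above removes the default values, so temperature and deterministic_synthesis must now be passed explicitly at every call site. A minimal call-site sketch, assuming the function also takes a leading model argument and that the old defaults are still what a caller wants; only the loss of the defaults comes from the diff:

    one_batch_masked_inplace_autoregression(
        model,                          # assumed leading argument, not shown in the hunk
        input,
        ar_mask,
        seq_logproba,
        temperature=1.0,                # formerly the default, now required
        deterministic_synthesis=False,  # formerly the default, now required
    )
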
@@ -123,9 +123,11 @@ class QuizMachine:
         n_backward = quizzes[:, 0] == self.token_backward
         backward = quizzes[n_backward]
         quizzes[n_backward] = self.reverse_time(quizzes[n_backward])
-        return not self.problem.trivial_prompts_and_answers(
-            quizzes[:, 1 : 1 + self.prompt_len],
-            quizzes[:, 2 + self.prompt_len :],
+        return torch.logical_not(
+            self.problem.trivial_prompts_and_answers(
+                quizzes[:, 1 : 1 + self.prompt_len],
+                quizzes[:, 2 + self.prompt_len :],
+            )
         )
 
     def reverse_time(self, quizzes):
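The hunk above swaps Python's not for torch.logical_not. Applied to a boolean tensor with more than one element, not raises "Boolean value of Tensor with more than one element is ambiguous", whereas torch.logical_not negates element-wise and keeps one flag per quiz. A minimal sketch, assuming trivial_prompts_and_answers returns a 1-D bool tensor:

    import torch

    trivial = torch.tensor([True, False, True])  # stand-in for trivial_prompts_and_answers(...)
    non_trivial = torch.logical_not(trivial)     # tensor([False,  True, False])
    # `not trivial` would raise a RuntimeError instead of returning a per-quiz mask.
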
@@ -258,7 +260,7 @@ class QuizMachine:
         quizzes,
         mistakes=None,
     ):
-        quizzes = quizzes.clone()
+        quizzes = quizzes.clone().to("cpu")
         n_forward = quizzes[quizzes[:, 0] == self.token_forward]
         n_backward = quizzes[:, 0] == self.token_backward
         backward = quizzes[n_backward]
@@ -269,8 +271,8 @@ class QuizMachine:
         predicted_answers = 1 - predicted_prompts
         if mistakes is not None:
             # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
-            predicted_prompts *= mistakes
-            predicted_answers *= mistakes
+            predicted_prompts *= mistakes.to("cpu")
+            predicted_answers *= mistakes.to("cpu")
         else:
             # 0/2 ~ not-to-predict / to predict
             predicted_prompts *= 2
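The two hunks above move the tensors onto the CPU before the in-place multiplications. If quizzes and mistakes still live on a CUDA device while predicted_prompts and predicted_answers are CPU tensors, the *= fails with a device-mismatch RuntimeError. A minimal sketch of the failure mode and the fix; shapes, dtypes and the availability check are assumptions:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    quizzes = torch.randint(10, (4, 7), device=device)
    mistakes = torch.tensor([1, -1, 1, -1], device=device)

    quizzes = quizzes.clone().to("cpu")                  # work on a CPU copy, as in the diff
    predicted_prompts = torch.ones(4, dtype=torch.long)  # CPU tensor
    # predicted_prompts *= mistakes                      # device mismatch when device == "cuda"
    predicted_prompts *= mistakes.to("cpu")              # same device, multiply succeeds
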
@@ -414,6 +416,25 @@ class QuizMachine:
         else:
             self.test_c_quizzes.append(new_c_quizzes)
 
+    def logproba_solution(self, models, c_quizzes):
+        logproba = c_quizzes.new_zeros(c_quizzes.size(0), len(models))
+
+        for model in models:
+            for input, l in zip(
+                c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
+            ):
+                ar_mask = self.make_ar_mask(input)
+                output = model(mygpt.BracketedSequence(input)).x
+                ce = (
+                    F.cross_entropy(output.transpose(1, 2), input, reduction="none")
+                    * ar_mask
+                )
+                l[:, model.id] = ce.sum(dim=-1)
+
+        return logproba
+
+    ###############################################################
+
     def compute_correctness(
         self,
         c_quizzes,
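
The new logproba_solution method sums, per quiz and per model, the cross-entropy of the tokens selected by make_ar_mask, i.e. the negative log-likelihood of the part each model is asked to predict (the name notwithstanding, no sign flip is applied). Note also that c_quizzes.new_zeros inherits the quizzes' dtype, so if c_quizzes holds integer tokens the accumulated float sums are truncated on assignment; whether that matters depends on how the result is consumed. A self-contained sketch of the per-sequence computation the inner loop performs; vocabulary size, shapes and the mask layout are assumptions:

    import torch
    import torch.nn.functional as F

    B, T, V = 2, 5, 8                  # batch, sequence length, vocabulary size
    logits = torch.randn(B, T, V)      # stand-in for model(mygpt.BracketedSequence(input)).x
    input = torch.randint(V, (B, T))   # token sequences, as in the quizzes
    ar_mask = torch.ones(B, T)         # 1 on the positions the model must predict
    ar_mask[:, :2] = 0                 # e.g. the prompt part is not scored

    ce = F.cross_entropy(logits.transpose(1, 2), input, reduction="none") * ar_mask
    neg_logproba = ce.sum(dim=-1)      # shape (B,): -log p(predicted part) per quiz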