quiz_machine: make temperature/deterministic_synthesis required, keep per-token seq_logproba, add non_trivial() and logproba_solution(), move tensors to CPU for display, and report accuracy percentages in logs.
[culture.git] / quiz_machine.py
index d4af770..c1477c9 100755 (executable)
@@ -27,8 +27,8 @@ def one_batch_masked_inplace_autoregression(
     input,
     ar_mask,
     seq_logproba,
-    temperature=1.0,
-    deterministic_synthesis=False,
+    temperature,
+    deterministic_synthesis,
 ):
     to_generate = (ar_mask.sum(0) > 0).nonzero()
 
@@ -50,7 +50,8 @@ def one_batch_masked_inplace_autoregression(
             t_next = dist.sample()
 
         all_n = torch.arange(t_next.size(0))
-        seq_logproba += logits[all_n, t_next].sum(dim=-1)
+
+        seq_logproba += logits[all_n, t_next]
 
         input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
 
@@ -116,6 +117,19 @@ class QuizMachine:
         ).all()
         return i_forward, i_backward
 
+    def non_trivial(self, quizzes):
+        quizzes = quizzes.clone()
+        n_forward = quizzes[quizzes[:, 0] == self.token_forward]
+        n_backward = quizzes[:, 0] == self.token_backward
+        backward = quizzes[n_backward]
+        quizzes[n_backward] = self.reverse_time(quizzes[n_backward])
+        return torch.logical_not(
+            self.problem.trivial_prompts_and_answers(
+                quizzes[:, 1 : 1 + self.prompt_len],
+                quizzes[:, 2 + self.prompt_len :],
+            )
+        )
+
     def reverse_time(self, quizzes):
         i_forward, i_backward = self.indices_forward_and_backward(quizzes)
 
@@ -246,7 +260,7 @@ class QuizMachine:
         quizzes,
         mistakes=None,
     ):
-        quizzes = quizzes.clone()
+        quizzes = quizzes.clone().to("cpu")
         n_forward = quizzes[quizzes[:, 0] == self.token_forward]
         n_backward = quizzes[:, 0] == self.token_backward
         backward = quizzes[n_backward]
@@ -257,8 +271,8 @@ class QuizMachine:
         predicted_answers = 1 - predicted_prompts
         if mistakes is not None:
             # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
-            predicted_prompts *= mistakes
-            predicted_answers *= mistakes
+            predicted_prompts *= mistakes.to("cpu")
+            predicted_answers *= mistakes.to("cpu")
         else:
             # 0/2 ~ not-to-predict / to predict
             predicted_prompts *= 2
@@ -359,11 +373,11 @@ class QuizMachine:
                 backward_nb_total = correct[n_backward].size(0)
 
                 self.logger(
-                    f"{log_prefix}_forward_accuracy {n_epoch} {model.id=} {forward_nb_correct} / {forward_nb_total}"
+                    f"{log_prefix}_forward_accuracy {n_epoch} model {model.id} nb_correct {forward_nb_correct} / {forward_nb_total} ({forward_nb_correct*100/forward_nb_total} %)"
                 )
 
                 self.logger(
-                    f"{log_prefix}_backward_accuracy {n_epoch} {model.id=} {backward_nb_correct} / {backward_nb_total}"
+                    f"{log_prefix}_backward_accuracy {n_epoch} model {model.id} nb_correct {backward_nb_correct} / {backward_nb_total} ({backward_nb_correct*100/backward_nb_total} %)"
                 )
 
             return result, correct
@@ -402,6 +416,25 @@ class QuizMachine:
         else:
             self.test_c_quizzes.append(new_c_quizzes)
 
+    def logproba_solution(self, models, c_quizzes):
+        logproba = c_quizzes.new_zeros(c_quizzes.size(0), len(models))
+
+        for model in models:
+            for input, l in zip(
+                c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
+            ):
+                ar_mask = self.make_ar_mask(input)
+                output = model(mygpt.BracketedSequence(input)).x
+                ce = (
+                    F.cross_entropy(output.transpose(1, 2), input, reduction="none")
+                    * ar_mask
+                )
+                l[:, model.id] = ce.sum(dim=-1)
+
+        return logproba
+
+    ###############################################################
+
     def compute_correctness(
         self,
         c_quizzes,
@@ -420,11 +453,11 @@ class QuizMachine:
 
         nb_correct = 0
 
+        seq_logproba[...] = 0.0
+
         for model in models_for_validation:
             result = c_quizzes.clone()
 
-            seq_logproba[...] = 0.0
-
             ar_mask = self.make_ar_mask(result)
 
             masked_inplace_autoregression(