Merge branch 'dev'
[culture.git] / quiz_machine.py
index ef766c4..bc468d3 100755 (executable)
@@ -219,9 +219,6 @@ class QuizMachine:
     def generate_token_sequences(self, nb):
         prompts, answers = self.problem.generate_prompts_and_answers(nb)
 
-        print(f"DEBUG {prompts.size()=} {answers.size()=}")
-        sys.stdout.flush()
-
         if self.prompt_len is None:
             self.prompt_len = prompts.size(1)
 
@@ -453,9 +450,13 @@ class QuizMachine:
 
     ######################################################################
 
-    def logproba_of_solutions(self, models, c_quizzes):
+    def solution_token_logprobas(self, models, c_quizzes):
         logproba = c_quizzes.new_zeros(
-            c_quizzes.size(0), len(models), device=self.device, dtype=torch.float32
+            c_quizzes.size(0),
+            len(models),
+            c_quizzes.size(1),
+            device=self.device,
+            dtype=torch.float32,
         )
 
         for model in models:
@@ -469,11 +470,12 @@ class QuizMachine:
                     input = input.to(self.device)
                     ar_mask = self.make_ar_mask(input)
                     output = model(mygpt.BracketedSequence(input)).x
-                    ce = (
-                        F.cross_entropy(output.transpose(1, 2), input, reduction="none")
+                    l[:, model.id] = (
+                        -F.cross_entropy(
+                            output.transpose(1, 2), input, reduction="none"
+                        )
                         * ar_mask
                     )
-                    l[:, model.id] = -ce.sum(dim=-1)
 
                 model.train(t)