Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 12 Aug 2024 13:55:59 +0000 (15:55 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 12 Aug 2024 13:55:59 +0000 (15:55 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index f51ab38..fbebbb9 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -574,7 +574,7 @@ def save_additional_results(model, models, science_w_quizzes):
     if science_w_quizzes is not None:
         struct = ("A", "f_A", "B", "f_B")
         mask = (0, 0, 0, 1)
-        result, correct = quiz_machine.predict(
+        result, correct, _ = quiz_machine.predict(
             model=model,
             quizzes=science_w_quizzes.to(main_device),
             struct=struct,
@@ -650,14 +650,33 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 
         solved_c_quizzes = c_quizzes[:, None, :].expand(-1, len(models), -1).clone()
 
+        seq_logproba = torch.zeros(
+            c_quizzes.size(0), len(models), device=solved_c_quizzes.device
+        )
+
         for m in models:
-            solved_c_quizzes[:, m.id] = quiz_machine.predict(
+            (
+                solved_c_quizzes[:, m.id],
+                _,
+                seq_logproba[:, m.id],
+            ) = quiz_machine.predict(
                 m,
                 solved_c_quizzes[:, m.id],
                 struct=("A", "f_A", "B", "f_B"),
                 mask=(0, 0, 0, 1),
             )
 
+        #!!!!!!!!!!!!!!!!!!!!
+        l = quiz_machine.models_logprobas(
+            models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
+        )
+        for s in range(seq_logproba.size(0)):
+            print(f"-- {s=} ----------------")
+            for m in range(seq_logproba.size(1)):
+                print("DEBUG", seq_logproba[s, m].item(), l[s, m].item())
+        exit(0)
+        #!!!!!!!!!!!!!!!!!!!!!!!!!
+
         # FINISH
 
         seq_logproba = quiz_machine.models_logprobas(
@@ -1314,7 +1333,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         record_new_c_quizzes(
             models,
             quiz_machine,
-            nb_errorsfor_train=args.nb_new_c_quizzes_for_train,
+            nb_for_train=args.nb_new_c_quizzes_for_train,
             nb_for_test=args.nb_new_c_quizzes_for_test,
         )
 
index 1d89cf4..6aa4e9b 100755 (executable)
@@ -28,7 +28,7 @@ def one_batch_masked_inplace_autoregression(
     model,
     input,
     ar_mask,
-    seq_logproba,
+    acc_seq_logproba,
     deterministic_synthesis=False,
 ):
     if input.size(0) == 0:
@@ -53,7 +53,7 @@ def one_batch_masked_inplace_autoregression(
 
         all_n = torch.arange(t_next.size(0))
 
-        seq_logproba += logits[all_n, t_next]
+        acc_seq_logproba += ar_mask[:, s] * logits[all_n, t_next]
 
         input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
 
@@ -107,14 +107,11 @@ class QuizMachine:
         model,
         input,
         ar_mask,
-        seq_logproba=None,
+        seq_logproba,
         progress_bar_desc=None,
     ):
         assert input.size() == ar_mask.size()
 
-        if seq_logproba is None:
-            seq_logproba = torch.empty(input.size(0), device=self.device)
-
         batches = zip(
             input.split(self.batch_size),
             ar_mask.split(self.batch_size),
@@ -138,7 +135,7 @@ class QuizMachine:
                     model=model,
                     input=input,
                     ar_mask=ar_mask,
-                    seq_logproba=seq_logproba,
+                    acc_seq_logproba=seq_logproba,
                     deterministic_synthesis=False,
                 )
 
@@ -190,10 +187,11 @@ class QuizMachine:
     ######################################################################
 
     def predict(self, model, quizzes, struct, mask):
+        quizzes = quizzes.to(self.device)
         ar_mask = self.make_quiz_mask(quizzes=quizzes, struct=struct, mask=mask)
         result = quizzes * (1 - ar_mask)
 
-        seq_logproba = torch.empty(quizzes.size(0), device=self.device)
+        seq_logproba = torch.zeros(quizzes.size(0), device=self.device)
 
         self.autoregression(
             model=model,
@@ -205,7 +203,11 @@ class QuizMachine:
 
         correct = (result == quizzes).min(dim=1).values.long()
 
-        return result, correct
+        result = result.to("cpu")
+        correct = correct.to("cpu")
+        seq_logproba = seq_logproba.to("cpu")
+
+        return result, correct, seq_logproba
 
     ######################################################################
 
@@ -221,7 +223,7 @@ class QuizMachine:
         for struct, mask_generate, _, _ in self.test_structures:
             i = self.problem.indices_select(quizzes=input, struct=struct)
             nb += i.long().sum()
-            result[i], correct[i] = self.predict(
+            result[i], correct[i], _ = self.predict(
                 model=model, quizzes=input[i], struct=struct, mask=mask_generate
             )
             predicted_parts[i] = torch.tensor(mask_generate, device=self.device)[