Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 27 Jul 2024 22:17:19 +0000 (00:17 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 27 Jul 2024 22:17:19 +0000 (00:17 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index fbc6f42..47952af 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -552,7 +552,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
             nb_to_generate_per_iteration,
             model_for_generation=model,
             procedure=c_quizzes_procedure,
-            to_recycle=to_recycle,
+            to_recycle=to_recycle,
         )
 
         # We discard the trivial ones, according to a criterion
index d4b463b..18c0828 100755 (executable)
@@ -433,7 +433,7 @@ class QuizMachine:
 
     ###############################################################
 
-    def generate_c_quizzes(self, nb, model_for_generation, procedure):
+    def generate_c_quizzes(self, nb, model_for_generation, procedure, to_recycle=None):
         seq_logproba = torch.zeros(nb, device=self.device)
 
         c_quizzes = None
@@ -454,6 +454,11 @@ class QuizMachine:
                 logit_transformer=t,
             )
 
+            if to_recycle is not None:
+                to_recycle = self.problem.reconfigure(to_recycle, s)
+                c_quizzes[: to_recycle.size(0)] = to_recycle
+                to_recycle = None
+
             c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B"))
 
         return c_quizzes.to("cpu")