From 55cae88e5196986b8e64823d577230b25bf99950 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 28 Jul 2024 00:17:19 +0200 Subject: [PATCH] Update. --- main.py | 2 +- quiz_machine.py | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index fbc6f42..47952af 100755 --- a/main.py +++ b/main.py @@ -552,7 +552,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 nb_to_generate_per_iteration, model_for_generation=model, procedure=c_quizzes_procedure, - # to_recycle=to_recycle, + to_recycle=to_recycle, ) # We discard the trivial ones, according to a criterion diff --git a/quiz_machine.py b/quiz_machine.py index d4b463b..18c0828 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -433,7 +433,7 @@ class QuizMachine: ############################################################### - def generate_c_quizzes(self, nb, model_for_generation, procedure): + def generate_c_quizzes(self, nb, model_for_generation, procedure, to_recycle=None): seq_logproba = torch.zeros(nb, device=self.device) c_quizzes = None @@ -454,6 +454,11 @@ class QuizMachine: logit_transformer=t, ) + if to_recycle is not None: + to_recycle = self.problem.reconfigure(to_recycle, s) + c_quizzes[: to_recycle.size(0)] = to_recycle + to_recycle = None + c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B")) return c_quizzes.to("cpu") -- 2.39.5