From ecc38fb79c17159785077617ae0b7ea0cad62fac Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 16 Jul 2024 20:03:53 +0200 Subject: [PATCH] Update. --- main.py | 2 +- quiz_machine.py | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/main.py b/main.py index 957e95a..6df33bd 100755 --- a/main.py +++ b/main.py @@ -442,7 +442,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 e = "???" log_string( - f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_create} (finishes {e})" + f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_create} (finishes {e} -- {(total_nb_validated * 3600)/duration}/h)" ) validated_quizzes = torch.cat(recorded, dim=0) diff --git a/quiz_machine.py b/quiz_machine.py index 0b84b36..faa640e 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -542,11 +542,7 @@ class QuizMachine: ############################################################### def generate_c_quizzes( - self, - nb, - model_for_generation, - forward_only=False, - generation_temperature=1.0 + self, nb, model_for_generation, forward_only=False, generation_temperature=1.0 ): c_quizzes = torch.empty( nb, @@ -578,7 +574,7 @@ class QuizMachine: input=c_quizzes, ar_mask=self.make_ar_mask(c_quizzes), seq_logproba=seq_logproba, - temperature=1.0 + temperature=1.0, deterministic_synthesis=False, device=self.device, ) -- 2.20.1