From: François Fleuret Date: Tue, 16 Jul 2024 18:03:53 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=ecc38fb79c17159785077617ae0b7ea0cad62fac;p=culture.git Update. --- diff --git a/main.py b/main.py index 957e95a..6df33bd 100755 --- a/main.py +++ b/main.py @@ -442,7 +442,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 e = "???" log_string( - f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_create} (finishes {e})" + f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_create} (finishes {e} -- {(total_nb_validated * 3600)/duration}/h)" ) validated_quizzes = torch.cat(recorded, dim=0) diff --git a/quiz_machine.py b/quiz_machine.py index 0b84b36..faa640e 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -542,11 +542,7 @@ class QuizMachine: ############################################################### def generate_c_quizzes( - self, - nb, - model_for_generation, - forward_only=False, - generation_temperature=1.0 + self, nb, model_for_generation, forward_only=False, generation_temperature=1.0 ): c_quizzes = torch.empty( nb, @@ -578,7 +574,7 @@ class QuizMachine: input=c_quizzes, ar_mask=self.make_ar_mask(c_quizzes), seq_logproba=seq_logproba, - temperature=1.0 + temperature=1.0, deterministic_synthesis=False, device=self.device, )