From 663591568040408d78d24add1d88eaa16ad6a7ab Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 15 Jul 2024 15:39:25 +0200 Subject: [PATCH] Update. --- main.py | 6 ++++-- quiz_machine.py | 23 ++++++++++++++--------- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/main.py b/main.py index 07fec96..4673f42 100755 --- a/main.py +++ b/main.py @@ -451,7 +451,7 @@ def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100): ) log_string( - f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create} (finish {e})" + f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create} (finishes {e})" ) # store the new c_quizzes which have been validated @@ -479,7 +479,9 @@ def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100): # s = " ".join([str(x.item()) for x in l]) # logp_file.write(s + "\n") - quiz_machine.save_quiz_illustrations(args.result_dir, prefix, vq) + quiz_machine.save_quiz_illustrations( + args.result_dir, prefix, vq, show_part_to_predict=False + ) ###################################################################### diff --git a/quiz_machine.py b/quiz_machine.py index bcb89ec..70daa0b 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -275,6 +275,7 @@ class QuizMachine: filename_prefix, quizzes, mistakes=None, + show_part_to_predict=True, ): quizzes = quizzes.clone().to("cpu") n_forward = quizzes[quizzes[:, 0] == self.token_forward] @@ -283,16 +284,20 @@ class QuizMachine: assert n_forward.size(0) + backward.size(0) == quizzes.size(0) quizzes[n_backward] = self.reverse_time(quizzes[n_backward]) - predicted_prompts = n_backward.long() - predicted_answers = 1 - predicted_prompts - if mistakes is not None: - # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct - predicted_prompts *= mistakes.to("cpu") - predicted_answers *= mistakes.to("cpu") + if show_part_to_predict: + predicted_prompts = n_backward.long() + predicted_answers = 1 - predicted_prompts + if mistakes is not None: + # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct + predicted_prompts *= mistakes.to("cpu") + predicted_answers *= mistakes.to("cpu") + else: + # 0/2 ~ not-to-predict / to predict + predicted_prompts *= 2 + predicted_answers *= 2 else: - # 0/2 ~ not-to-predict / to predict - predicted_prompts *= 2 - predicted_answers *= 2 + predicted_prompts = None + predicted_answers = None self.problem.save_quiz_illustrations( result_dir, -- 2.39.5