From d3d4ce7bb2b799f4bf81a936987e3a8938514af8 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 6 Jul 2024 07:37:52 +0300 Subject: [PATCH] Update. --- quizz_machine.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/quizz_machine.py b/quizz_machine.py index c6c2f95..92b5799 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -238,7 +238,7 @@ class QuizzMachine: result_dir, "culture_w_quizzes", self.train_w_quizzes[:72], - show_to_be_predicted=True, + n_backward=self.train_w_quizzes[:72, 0] == self.token_backward, ) def save_quizzes( @@ -246,7 +246,7 @@ class QuizzMachine: result_dir, filename_prefix, quizzes, - show_to_be_predicted=False, + n_backward=None, mistakes=None, ): quizzes = quizzes.clone() @@ -256,8 +256,11 @@ class QuizzMachine: assert forward.size(0) + backward.size(0) == quizzes.size(0) quizzes[ib] = self.reverse_time(quizzes[ib]) - if show_to_be_predicted: - predicted_prompts = ib.long() + if n_backward is None: + predicted_prompts = None + predicted_answers = None + else: + predicted_prompts = n_backward.long() predicted_answers = 1 - predicted_prompts if mistakes is not None: # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct @@ -267,9 +270,6 @@ class QuizzMachine: # 0/2 ~ not-to-predict / to predict predicted_prompts *= 2 predicted_answers *= 2 - else: - predicted_prompts = None - predicted_answers = None self.problem.save_quizzes( result_dir, @@ -390,7 +390,7 @@ class QuizzMachine: result_dir, f"culture_prediction_{n_epoch:04d}_{model.id:02d}", quizzes=test_result[:72], - show_to_be_predicted=True, + n_backward=self.test_w_quizzes[:72, 0] == self.token_backward, mistakes=test_correct[:72] * 2 - 1, ) -- 2.20.1