From 74b0f2faba5ee386388c4df87fabf353a177d833 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 6 Jul 2024 01:16:49 +0300 Subject: [PATCH] Update. --- quizz_machine.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/quizz_machine.py b/quizz_machine.py index 717e8ac..c6c2f95 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -360,28 +360,28 @@ class QuizzMachine: result[n_backward], correct[n_backward] = compute_accuracy(back_input) if log_prefix is not None: - nb_correct = correct[n_forward].sum() - nb_total = correct[n_forward].size(0) - back_nb_correct = correct[n_backward].sum() - back_nb_total = correct[n_backward].size(0) + forward_nb_correct = correct[n_forward].sum() + forward_nb_total = correct[n_forward].size(0) + backward_nb_correct = correct[n_backward].sum() + backward_nb_total = correct[n_backward].size(0) self.logger( - f"accuracy {log_prefix} {n_epoch} {model.id=} {nb_correct} / {nb_total}" + f"forward_accuracy {log_prefix} {n_epoch} {model.id=} {forward_nb_correct} / {forward_nb_total}" ) self.logger( - f"back_accuracy {log_prefix} {n_epoch} {model.id=} {back_nb_correct} / {back_nb_total}" + f"backward_accuracy {log_prefix} {n_epoch} {model.id=} {backward_nb_correct} / {backward_nb_total}" ) return result, correct compute_accuracy(self.train_w_quizzes[:nmax], log_prefix="train") - result, correct = compute_accuracy( + test_result, test_correct = compute_accuracy( self.test_w_quizzes[:nmax], log_prefix="test" ) - main_test_accuracy = correct.sum() / correct.size(0) + main_test_accuracy = test_correct.sum() / test_correct.size(0) self.logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}") ############################## @@ -389,9 +389,9 @@ class QuizzMachine: self.save_quizzes( result_dir, f"culture_prediction_{n_epoch:04d}_{model.id:02d}", - quizzes=result[:72], + quizzes=test_result[:72], show_to_be_predicted=True, - mistakes=correct[:72] * 2 - 1, + mistakes=test_correct[:72] * 2 - 1, ) return main_test_accuracy -- 2.20.1