From 7943f4ee840012ed3d71e76f24b008776ab2b238 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 19 Jul 2024 08:21:17 +0200 Subject: [PATCH] Update. --- main.py | 6 +++--- quiz_machine.py | 8 +++----- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/main.py b/main.py index ab87b56..d9257db 100755 --- a/main.py +++ b/main.py @@ -376,7 +376,7 @@ def one_epoch(model, quiz_machine, local_device=main_device): acc_train_loss += loss.item() * input.size(0) loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1) - n_p2a = input[:, 0] == quiz_machine.token_p2a + n_p2a = input[:, 0] == quiz_machine.problem.token_forward to_store = from_w & n_p2a.to("cpu") if to_store.any(): hard_w_quizzes.append( @@ -496,11 +496,11 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 v_train = validated_quizzes[:nb_for_train] quiz_machine.store_c_quizzes(v_train, for_train=True) - quiz_machine.store_c_quizzes(quiz_machine.reverse_time(v_train), for_train=True) + quiz_machine.store_c_quizzes(quiz_machine.p_a_flip(v_train), for_train=True) v_test = validated_quizzes[nb_for_train:nb_to_create] quiz_machine.store_c_quizzes(v_test, for_train=False) - quiz_machine.store_c_quizzes(quiz_machine.reverse_time(v_test), for_train=False) + quiz_machine.store_c_quizzes(quiz_machine.p_a_flip(v_test), for_train=False) ###################################################################### # save images diff --git a/quiz_machine.py b/quiz_machine.py index 51c3f08..cc81086 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -281,8 +281,8 @@ class QuizMachine: self.problem.save_quiz_illustrations( result_dir, filename_prefix, - quizzes[:, 1 : 1 + self.prompt_len], - quizzes[:, 2 + self.prompt_len :], + quizzes[:, : self.prompt_len], + quizzes[:, self.prompt_len :], predicted_prompts, predicted_answers, ) @@ -358,9 +358,7 @@ class QuizMachine: if self.back_accuracy and n_a2p.any(): # accuracy of B->A*->B*=B instead of B->A*=A back_input = self.p_a_flip(result[n_a2p]) - back_input[:, 2 + self.prompt_len :] = input[ - n_a2p, 1 : 1 + self.answer_len - ] + back_input[:, 1 + self.prompt_len :] = input[n_a2p, 1 : self.answer_len] _, correct[n_a2p] = compute_accuracy(back_input) if log_prefix is not None: -- 2.39.5