From 99afb56af8da06ea364a5f1c7158c179697bb0c9 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 10 Sep 2024 13:20:45 +0200 Subject: [PATCH] Update. --- main.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/main.py b/main.py index fed8abc..c1ef5bc 100755 --- a/main.py +++ b/main.py @@ -1080,11 +1080,7 @@ def quiz_validation(models, c_quizzes, local_device): mask_generate = quiz_machine.make_quiz_mask( quizzes=c_quizzes, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad ) - result = ae_generate( - model, - (1 - mask_generate) * c_quizzes, - mask_generate, - ) + result = ae_generate(model, (1 - mask_generate) * c_quizzes, mask_generate) nb_mistakes = (result != c_quizzes).long().sum(dim=1) correct = correct & (nb_mistakes == 0) wrong = wrong | (nb_mistakes >= nb_mistakes_to_be_wrong) @@ -1201,9 +1197,11 @@ def thread_generate_ae_c_quizzes(models, nb, record, local_device=main_device): def save_c_quizzes_with_scores(models, c_quizzes, filename): l = [] - for model in models: - model.eval().to(main_device) - l.append(model_ae_proba_solutions(model, c_quizzes)) + + with torch.autograd.no_grad(): + for model in models: + model = copy.deepcopy(model).to(main_device).eval() + l.append(model_ae_proba_solutions(model, c_quizzes)) probas = torch.cat([x[:, None] for x in l], dim=1) @@ -1215,7 +1213,7 @@ def save_c_quizzes_with_scores(models, c_quizzes, filename): quiz_machine.problem.save_quizzes_as_image( args.result_dir, filename, - quizzes=subset_c_quizzes, + quizzes=c_quizzes, comments=comments, delta=True, nrow=8, @@ -1233,9 +1231,7 @@ if args.resume: filename = f"ae_{model.id:03d}.pth" try: - d = torch.load( - os.path.join(args.result_dir, filename), map_location=main_device - ) + d = torch.load(os.path.join(args.result_dir, filename), map_location="cpu") model.load_state_dict(d["state_dict"]) model.optimizer.load_state_dict(d["optimizer_state_dict"]) model.test_accuracy = d["test_accuracy"] -- 2.39.5