From 6d77075428ed95b0626ceb6ab55acabba03720ab Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 21 Sep 2024 09:33:41 +0200 Subject: [PATCH] Update. --- main.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index 0311450..0056c76 100755 --- a/main.py +++ b/main.py @@ -598,7 +598,9 @@ def one_complete_epoch( # Compute the test accuracy - quizzes = generate_quiz_set(args.nb_test_samples, c_quizzes, args.c_quiz_multiplier) + quizzes = generate_quiz_set( + args.nb_test_samples, test_c_quizzes, args.c_quiz_multiplier + ) imt_set = samples_for_prediction_imt(quizzes.to(local_device)) result = ae_predict(model, imt_set, local_device=local_device).to("cpu") correct = (quizzes == result).min(dim=1).values.long() @@ -611,7 +613,11 @@ def one_complete_epoch( ) save_inference_images( - model, n_epoch, c_quizzes, args.c_quiz_multiplier, local_device=local_device + model, + n_epoch, + train_c_quizzes, + args.c_quiz_multiplier, + local_device=local_device, ) -- 2.39.5