From 9b03df47520cff5c5da7f0655861a64ffc9c0e1a Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 1 Sep 2024 21:54:45 +0200 Subject: [PATCH] Update. --- main.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 9801702..120e19c 100755 --- a/main.py +++ b/main.py @@ -1134,7 +1134,7 @@ def targets_and_prediction(model, input, mask_generate): return targets, logits -def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device): +def run_ae_test(model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_device): with torch.autograd.no_grad(): model.eval().to(local_device) @@ -1147,7 +1147,8 @@ def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device): args.nb_test_samples, data_structures, local_device, - "test", + c_quizzes=c_quizzes, + desc="test", ): targets, logits = targets_and_prediction(model, input, mask_generate) loss = NTC_masked_cross_entropy(logits, targets, mask_loss) @@ -1167,6 +1168,7 @@ def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device): args.nb_test_samples, data_structures, local_device, + c_quizzes, "test", ): targets = input.clone() @@ -1282,7 +1284,7 @@ def one_ae_epoch( f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}" ) - run_ae_test(model, other_models, quiz_machine, n_epoch, local_device=local_device) + run_ae_test(model, quiz_machine, n_epoch, c_quizzes, local_device=local_device) ###################################################################### -- 2.39.5