Update: thread c_quizzes through run_ae_test.
author     François Fleuret <francois@fleuret.org>
           Sun, 1 Sep 2024 19:54:45 +0000 (21:54 +0200)
committer  François Fleuret <francois@fleuret.org>
           Sun, 1 Sep 2024 19:54:45 +0000 (21:54 +0200)
main.py

diff --git a/main.py b/main.py
index 9801702..120e19c 100755
--- a/main.py
+++ b/main.py
@@ -1134,7 +1134,7 @@ def targets_and_prediction(model, input, mask_generate):
     return targets, logits
 
 
-def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device):
+def run_ae_test(model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_device):
     with torch.autograd.no_grad():
         model.eval().to(local_device)
 
@@ -1147,7 +1147,8 @@ def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device):
             args.nb_test_samples,
             data_structures,
             local_device,
-            "test",
+            c_quizzes=c_quizzes,
+            desc="test",
         ):
             targets, logits = targets_and_prediction(model, input, mask_generate)
             loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
@@ -1167,6 +1168,7 @@ def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device):
             args.nb_test_samples,
             data_structures,
             local_device,
+            c_quizzes,
             "test",
         ):
             targets = input.clone()
@@ -1282,7 +1284,7 @@ def one_ae_epoch(
         f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"
     )
 
-    run_ae_test(model, other_models, quiz_machine, n_epoch, local_device=local_device)
+    run_ae_test(model, quiz_machine, n_epoch, c_quizzes, local_device=local_device)
 
 
 ######################################################################
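For readers following the change, below is a minimal sketch (plain Python, not the repository's code) of how c_quizzes now flows from one_ae_epoch through run_ae_test into the batch iterator. The iterator name ae_batches, its stub body, and the dummy arguments are assumptions for illustration; only the parameter lists visible in the diff above are taken from it.

# Minimal sketch of the call flow after this commit; everything inside the
# stubs is an assumption, not the repository's actual implementation.

import torch


def ae_batches(quiz_machine, nb_samples, data_structures, local_device,
               c_quizzes=None, desc=None):
    # Hypothetical stand-in for the iterator called in run_ae_test.
    # Yields (input, mask_generate, mask_loss) triples; when c_quizzes is
    # given, the real iterator presumably mixes those quizzes into the batches.
    for _ in range(nb_samples):
        x = torch.zeros(1, 8, dtype=torch.long, device=local_device)
        yield x, torch.zeros_like(x), torch.ones_like(x)


def run_ae_test(model, quiz_machine, n_epoch, c_quizzes=None, local_device="cpu"):
    # New in this commit: c_quizzes is accepted and forwarded to the batch
    # iterator, so the test pass can also evaluate on the c_quizzes.
    with torch.autograd.no_grad():
        for input, mask_generate, mask_loss in ae_batches(
            quiz_machine, 4, None, local_device,
            c_quizzes=c_quizzes, desc="test",
        ):
            pass  # targets_and_prediction / loss computation elided


def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device="cpu"):
    # Call site after the commit: the stale other_models argument is gone
    # and c_quizzes is passed through positionally.
    run_ae_test(model, quiz_machine, n_epoch, c_quizzes, local_device=local_device)


one_ae_epoch(model=None, quiz_machine=None, n_epoch=0, c_quizzes=None)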