From 354a99c2e20a6e5fea923a45477d0ea9ac3306a0 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= <francois@fleuret.org>
Date: Tue, 10 Sep 2024 10:09:35 +0200
Subject: [PATCH] Update.

---
 main.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/main.py b/main.py
index fe1aed1..fed8abc 100755
--- a/main.py
+++ b/main.py
@@ -1200,7 +1200,10 @@ def thread_generate_ae_c_quizzes(models, nb, record, local_device=main_device):
 
 
 def save_c_quizzes_with_scores(models, c_quizzes, filename):
-    l = [model_ae_proba_solutions(model, c_quizzes) for model in models]
+    l = []
+    for model in models:
+        model.eval().to(main_device)
+        l.append(model_ae_proba_solutions(model, c_quizzes))
 
     probas = torch.cat([x[:, None] for x in l], dim=1)
 
-- 
2.39.5