Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 11 Sep 2024 14:01:49 +0000 (16:01 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 11 Sep 2024 14:01:49 +0000 (16:01 +0200)
main.py

diff --git a/main.py b/main.py
index e01e57a..fbfbdf8 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -117,7 +117,7 @@ parser.add_argument("--prompt_noise", type=float, default=0.05)
 
 parser.add_argument("--nb_hints", type=int, default=5)
 
-parser.add_argument("--nb_runs", type=int, default=5)
+parser.add_argument("--nb_runs", type=int, default=1)
 
 parser.add_argument("--dirty_debug", action="store_true", default=False)
 
@@ -1218,7 +1218,7 @@ def thread_generate_ae_c_quizzes(models, nb, record, local_device=main_device):
 ######################################################################
 
 
-def save_c_quizzes_with_scores(models, c_quizzes, nb, filename, solvable_only=False):
+def save_c_quizzes_with_scores(models, c_quizzes, filename, solvable_only=False):
     l = []
 
     with torch.autograd.no_grad():
@@ -1424,10 +1424,13 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
         filename = f"culture_c_quiz_{n_epoch:04d}.png"
         save_c_quizzes_with_scores(
-            models, c_quizzes, 256, filename, solvable_only=False
+            models, c_quizzes[:256], filename, solvable_only=False
         )
+
         filename = f"culture_c_quiz_{n_epoch:04d}_solvable.png"
-        save_c_quizzes_with_scores(models, c_quizzes, 256, filename, solvable_only=True)
+        save_c_quizzes_with_scores(
+            models, c_quizzes[:256], filename, solvable_only=True
+        )
 
         log_string(f"generated_c_quizzes {c_quizzes.size()=}")