Update.
author François Fleuret <francois@fleuret.org>
Wed, 14 Aug 2024 11:04:54 +0000 (13:04 +0200)
committer François Fleuret <francois@fleuret.org>
Wed, 14 Aug 2024 11:04:54 +0000 (13:04 +0200)
main.py

diff --git a/main.py b/main.py
index 0b9a86e..0bbcc6b 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -528,7 +528,7 @@ def model_proba_solutions(m, quizzes):
 
 def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
     nb_validated, nb_to_validate = 0, (nb_for_train + nb_for_test) * len(models)
-    nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate
+    nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate // 10
 
     start_time = time.perf_counter()
 
@@ -759,11 +759,10 @@ class Thinker(nn.Module):
 
 
 if args.test == "func":
-    train_input = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples)
     test_input = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples)
 
-    L = train_input.size(1) // 4
-    f_len = 25
+    L = test_input.size(1) // 4
+    f_len = 50
 
     model = Thinker(
         vocabulary_size=vocabulary_size,
@@ -772,7 +771,7 @@ if args.test == "func":
         dim_hidden=args.dim_hidden,
         nb_heads=args.nb_heads,
         nb_blocks=args.nb_blocks,
-        f_len=20,
+        f_len=f_len,
         dropout=args.dropout,
     ).to(main_device)
 
@@ -781,6 +780,8 @@ if args.test == "func":
     for n_epoch in range(args.nb_epochs):
         model.train()
 
+        train_input = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples)
+
         nb_train_samples, acc_train_loss = 0, 0.0
 
         for input in tqdm.tqdm(