Update.
[culture.git] / main.py
diff --git a/main.py b/main.py
index d0de5af..b7b55b5 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -82,6 +82,8 @@ parser.add_argument("--deterministic_synthesis", action="store_true", default=Fa
 
 parser.add_argument("--nb_gpts", type=int, default=5)
 
 
 parser.add_argument("--nb_gpts", type=int, default=5)
 
+parser.add_argument("--nb_correct_to_validate", type=int, default=4)
+
 parser.add_argument("--dirty_debug", action="store_true", default=False)
 
 ######################################################################
 parser.add_argument("--dirty_debug", action="store_true", default=False)
 
 ######################################################################
@@ -361,10 +363,9 @@ def create_c_quizzes(
 
     model_indexes = []
     sum_logits, sum_nb_c_quizzes = 0, 0
 
     model_indexes = []
     sum_logits, sum_nb_c_quizzes = 0, 0
-    nb_correct_to_validate = len(models) - 1
 
     while (
 
     while (
-        sum([x.size(0) for x in recorded[nb_correct_to_validate]])
+        sum([x.size(0) for x in recorded[args.nb_correct_to_validate]])
         < nb_for_train + nb_for_test
     ):
         nb_to_validate = nb_for_train + nb_for_test
         < nb_for_train + nb_for_test
     ):
         nb_to_validate = nb_for_train + nb_for_test
@@ -395,7 +396,7 @@ def create_c_quizzes(
         for n in range(nb_correct.max() + 1):
             recorded[n].append(new_c_quizzes[nb_correct == n].clone())
 
         for n in range(nb_correct.max() + 1):
             recorded[n].append(new_c_quizzes[nb_correct == n].clone())
 
-        nb_validated = sum([x.size(0) for x in recorded[nb_correct_to_validate]])
+        nb_validated = sum([x.size(0) for x in recorded[args.nb_correct_to_validate]])
         nb_generated = sum(
             [sum([x.size(0) for x in recorded[n]]) for n in recorded.keys()]
         )
         nb_generated = sum(
             [sum([x.size(0) for x in recorded[n]]) for n in recorded.keys()]
         )
@@ -413,13 +414,13 @@ def create_c_quizzes(
         else:
             del recorded[n]
 
         else:
             del recorded[n]
 
-    new_c_quizzes = recorded[nb_correct_to_validate][: nb_for_train + nb_for_test]
+    new_c_quizzes = recorded[args.nb_correct_to_validate][: nb_for_train + nb_for_test]
 
     quizz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
     quizz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
 
     for n in recorded.keys():
 
     quizz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
     quizz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
 
     for n in recorded.keys():
-        s = "_validated" if n == nb_correct_to_validate else ""
+        s = "_validated" if n == args.nb_correct_to_validate else ""
         quizz_machine.problem.save_quizzes(
             recorded[n][:72],
             args.result_dir,
         quizz_machine.problem.save_quizzes(
             recorded[n][:72],
             args.result_dir,