Update.
[culture.git] / main.py
diff --git a/main.py b/main.py
index 11eb8fd..9d95034 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -79,6 +79,8 @@ parser.add_argument("--dropout", type=float, default=0.1)
 
 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
 
+parser.add_argument("--reverse_cleanup", action="store_true", default=False)
+
 parser.add_argument("--problem", type=str, default="sky")
 
 parser.add_argument("--nb_gpts", type=int, default=5)
@@ -408,6 +410,10 @@ def create_c_quizzes(
 
     nb_to_create = nb_for_train + nb_for_test
 
+    warnings.warn(
+        f"{args.nb_gpts=} {args.nb_models_for_generation=} {args.min_to_validate=} {args.max_to_validate=}"
+    )
+
     while nb_validated() < nb_to_create:
         (
             new_c_quizzes,
@@ -418,6 +424,7 @@ def create_c_quizzes(
             nb_models_for_generation=args.nb_models_for_generation,
             models=models,
             mode=args.generation_mode,
+            reverse_cleanup=args.reverse_cleanup,
             min_ave_seq_logproba=min_ave_seq_logproba,
             n_epoch=n_epoch,
             result_dir=args.result_dir,
@@ -434,7 +441,8 @@ def create_c_quizzes(
         for n in range(nb_correct.max() + 1):
             recorded[n].append(new_c_quizzes[nb_correct == n].clone())
 
-        nv = [recorded[n][-1].size(0) for n in recorded.keys()]
+        nv = F.one_hot(nb_correct, num_classes=len(models) + 1).sum(0)
+        nv = " ".join([str(x.item()) for x in nv])
 
         log_string(f"keep c_quizzes kept {nv} total {nb_validated()} / {nb_to_create}")