Update.
[culture.git] / main.py
diff --git a/main.py b/main.py
index fd8ab41..9d95034 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -410,6 +410,10 @@ def create_c_quizzes(
 
     nb_to_create = nb_for_train + nb_for_test
 
 
     nb_to_create = nb_for_train + nb_for_test
 
+    warnings.warn(
+        f"{args.nb_gpts=} {args.nb_models_for_generation=} {args.min_to_validate=} {args.max_to_validate=}"
+    )
+
     while nb_validated() < nb_to_create:
         (
             new_c_quizzes,
     while nb_validated() < nb_to_create:
         (
             new_c_quizzes,
@@ -437,7 +441,8 @@ def create_c_quizzes(
         for n in range(nb_correct.max() + 1):
             recorded[n].append(new_c_quizzes[nb_correct == n].clone())
 
         for n in range(nb_correct.max() + 1):
             recorded[n].append(new_c_quizzes[nb_correct == n].clone())
 
-        nv = [recorded[n][-1].size(0) for n in recorded.keys()]
+        nv = F.one_hot(nb_correct, num_classes=len(models) + 1).sum(0)
+        nv = " ".join([str(x.item()) for x in nv])
 
         log_string(f"keep c_quizzes kept {nv} total {nb_validated()} / {nb_to_create}")
 
 
         log_string(f"keep c_quizzes kept {nv} total {nb_validated()} / {nb_to_create}")