Update.
[culture.git] / main.py
diff --git a/main.py b/main.py
index 3acf595..ca0d152 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -879,7 +879,7 @@ def run_tests(model, task, deterministic_synthesis):
 
             nb_test_samples += input.size(0)
 
-        task.produce_results(
+        main_test_accuracy = task.produce_results(
             n_epoch=n_epoch,
             model=model,
             result_dir=args.result_dir,
@@ -888,8 +888,11 @@ def run_tests(model, task, deterministic_synthesis):
         )
 
         test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
+
         log_string(f"test_perplexity {n_epoch} {test_perplexity}")
 
+    return main_test_accuracy
+
 
 ######################################################################
 
@@ -898,16 +901,16 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
 
     one_epoch(model, task, learning_rate)
 
-    run_tests(model, task, deterministic_synthesis=False)
+    test_accuracy = run_tests(model, task, deterministic_synthesis=False)
 
     # --------------------------------------------
 
-    if n_epoch >= 3:
-        nb_required = 100
+    if test_accuracy >= 0.8:
+        nb_for_train, nb_for_test = 1000, 100
         kept = []
 
-        while sum([x.size(0) for x in kept]) < nb_required:
-            new_problems, nb_correct = task.create_new_problems(
+        while sum([x.size(0) for x in kept]) < nb_for_train + nb_for_test:
+            new_quizzes, nb_correct = task.create_new_quizzes(
                 n_epoch=n_epoch,
                 result_dir=args.result_dir,
                 logger=log_string,
@@ -916,14 +919,17 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
                 nb_runs=10,
             )
 
-            to_keep = new_problems[torch.logical_and(nb_correct >= 8, nb_correct < 10)]
-            log_string(f"keep {to_keep.size(0)} problems")
+            to_keep = new_quizzes[torch.logical_and(nb_correct >= 8, nb_correct < 10)]
+            log_string(f"keep {to_keep.size(0)} quizzes")
             kept.append(to_keep)
 
-        new_problems = torch.cat(kept, dim=0)[:nb_required]
-        task.store_new_problems(new_problems)
+        new_quizzes = torch.cat(kept, dim=0)[: nb_for_train + nb_for_test]
+
+        task.store_new_quizzes(new_quizzes[:nb_for_train], train=True)
+        task.store_new_quizzes(new_quizzes[nb_for_train:], train=False)
+
         task.save_image(
-            new_problems[:96],
+            new_quizzes[:96],
             args.result_dir,
             f"world_new_{n_epoch:04d}.png",
             log_string,