Update.
[culture.git] / main.py
diff --git a/main.py b/main.py
index 35f02a3..ca0d152 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -474,6 +474,7 @@ elif args.task == "world":
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
         batch_size=args.physical_batch_size,
+        result_dir=args.result_dir,
         logger=log_string,
         device=device,
     )
@@ -878,7 +879,7 @@ def run_tests(model, task, deterministic_synthesis):
 
             nb_test_samples += input.size(0)
 
-        task.produce_results(
+        main_test_accuracy = task.produce_results(
             n_epoch=n_epoch,
             model=model,
             result_dir=args.result_dir,
@@ -887,8 +888,11 @@ def run_tests(model, task, deterministic_synthesis):
         )
 
         test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
+
         log_string(f"test_perplexity {n_epoch} {test_perplexity}")
 
+    return main_test_accuracy
+
 
 ######################################################################
 
@@ -897,16 +901,16 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
 
     one_epoch(model, task, learning_rate)
 
-    run_tests(model, task, deterministic_synthesis=False)
+    test_accuracy = run_tests(model, task, deterministic_synthesis=False)
 
     # --------------------------------------------
 
-    if n_epoch >= 3:
-        nb_required = 1000
+    if test_accuracy >= 0.8:
+        nb_for_train, nb_for_test = 1000, 100
         kept = []
 
-        while sum([x.size(0) for x in kept]) < nb_required:
-            new_problems, nb_correct = task.create_new_problems(
+        while sum([x.size(0) for x in kept]) < nb_for_train + nb_for_test:
+            new_quizzes, nb_correct = task.create_new_quizzes(
                 n_epoch=n_epoch,
                 result_dir=args.result_dir,
                 logger=log_string,
@@ -915,11 +919,21 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
                 nb_runs=10,
             )
 
-            to_keep = new_problems[torch.logical_and(nb_correct >= 8, nb_correct < 10)]
-            log_string(f"keep {to_keep.size(0)} problems")
+            to_keep = new_quizzes[torch.logical_and(nb_correct >= 8, nb_correct < 10)]
+            log_string(f"keep {to_keep.size(0)} quizzes")
             kept.append(to_keep)
 
-        new_problems = torch.cat(kept, dim=0)[:nb_required]
+        new_quizzes = torch.cat(kept, dim=0)[: nb_for_train + nb_for_test]
+
+        task.store_new_quizzes(new_quizzes[:nb_for_train], train=True)
+        task.store_new_quizzes(new_quizzes[nb_for_train:], train=False)
+
+        task.save_image(
+            new_quizzes[:96],
+            args.result_dir,
+            f"world_new_{n_epoch:04d}.png",
+            log_string,
+        )
 
     # --------------------------------------------