Update: pass result_dir to the world task, fix the train/test perplexity log labels, have run_tests return the main test accuracy, and, once that accuracy reaches 0.8, generate new quizzes, keep those solved in 8 or 9 of 10 runs, and add them to the training and test sets.
diff --git a/main.py b/main.py
index 18b19db..ca0d152 100755
--- a/main.py
+++ b/main.py
@@ -474,6 +474,7 @@ elif args.task == "world":
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
         batch_size=args.physical_batch_size,
+        result_dir=args.result_dir,
         logger=log_string,
         device=device,
     )
@@ -853,7 +854,7 @@ def one_epoch(model, task, learning_rate):
 
     train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
 
-    log_string(f"train)perplexity {n_epoch} {train_perplexity}")
+    log_string(f"train_perplexity {n_epoch} {train_perplexity}")
 
 
 ######################################################################
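As an aside on the perplexity lines kept as context here: the value logged is the exponential of the average per-sample loss, and the min(100, ...) clamp bounds the exponent so math.exp cannot overflow if the loss diverges. A tiny illustrative sketch (the accumulated loss and sample count below are made up):

    import math

    # Illustrative values only: an accumulated loss of 350.0 over 100 samples.
    acc_train_loss, nb_train_samples = 350.0, 100

    # Average loss 3.5 gives perplexity exp(3.5), roughly 33.1; the clamp caps
    # the reported value at exp(100) and keeps math.exp from overflowing.
    train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
    print(train_perplexity)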
@@ -878,7 +879,7 @@ def run_tests(model, task, deterministic_synthesis):
 
             nb_test_samples += input.size(0)
 
-        task.produce_results(
+        main_test_accuracy = task.produce_results(
             n_epoch=n_epoch,
             model=model,
             result_dir=args.result_dir,
@@ -887,7 +888,10 @@ def run_tests(model, task, deterministic_synthesis):
         )
 
         test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
-        log_string(f"test)perplexity {n_epoch} {test_perplexity}")
+
+        log_string(f"test_perplexity {n_epoch} {test_perplexity}")
+
+    return main_test_accuracy
 
 
 ######################################################################
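Note that returning main_test_accuracy from run_tests assumes task.produce_results itself returns the test accuracy. A minimal sketch of that assumed contract; the class name, counts, and trimmed parameter list below are illustrative, not the project's actual Task implementation:

    class ToyTask:
        def produce_results(self, n_epoch, model, result_dir, **kwargs):
            # Stand-in evaluation: report the fraction of test quizzes solved.
            nb_correct, nb_total = 87, 100  # illustrative counts
            return nb_correct / nb_total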
@@ -897,7 +901,39 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
 
     one_epoch(model, task, learning_rate)
 
-    run_tests(model, task, deterministic_synthesis=True)
+    test_accuracy = run_tests(model, task, deterministic_synthesis=False)
+
+    # --------------------------------------------
+
+    if test_accuracy >= 0.8:
+        nb_for_train, nb_for_test = 1000, 100
+        kept = []
+
+        while sum([x.size(0) for x in kept]) < nb_for_train + nb_for_test:
+            new_quizzes, nb_correct = task.create_new_quizzes(
+                n_epoch=n_epoch,
+                result_dir=args.result_dir,
+                logger=log_string,
+                nb=nb_for_train + nb_for_test,
+                model=model,
+                nb_runs=10,
+            )
+
+            to_keep = new_quizzes[torch.logical_and(nb_correct >= 8, nb_correct < 10)]
+            log_string(f"keep {to_keep.size(0)} quizzes")
+            kept.append(to_keep)
+
+        new_quizzes = torch.cat(kept, dim=0)[: nb_for_train + nb_for_test]
+
+        task.store_new_quizzes(new_quizzes[:nb_for_train], train=True)
+        task.store_new_quizzes(new_quizzes[nb_for_train:], train=False)
+
+        task.save_image(
+            new_quizzes[:96],
+            args.result_dir,
+            f"world_new_{n_epoch:04d}.png",
+            log_string,
+        )
 
     # --------------------------------------------
 
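For reference, the acceptance rule in the last hunk keeps a generated quiz only when its nb_correct count, presumably the number of the nb_runs=10 attempts the model got right, is 8 or 9, i.e. quizzes the model can usually, but not always, solve. A standalone sketch of just that selection step, with toy tensors standing in for the real quizzes and per-quiz counts (everything except the thresholds is illustrative):

    import torch

    # Toy stand-ins: five candidate quizzes of length 4, and the number of
    # runs (out of 10) in which the model answered each one correctly.
    new_quizzes = torch.arange(20).reshape(5, 4)
    nb_correct = torch.tensor([10, 9, 8, 7, 0])

    # Keep quizzes solved often but not always: 8 or 9 correct runs out of 10.
    to_keep = new_quizzes[torch.logical_and(nb_correct >= 8, nb_correct < 10)]
    print(to_keep.size(0))  # 2 quizzes pass the filter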