Update: parametrize the sky problem from the command line, draw a single random model for c-quiz generation, and train only the weakest model each epoch.
diff --git a/main.py b/main.py
index 30dcd4d..0a7be99 100755
--- a/main.py
+++ b/main.py
@@ -79,14 +79,12 @@ parser.add_argument("--dropout", type=float, default=0.1)
 
 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
 
+parser.add_argument("--reverse_cleanup", action="store_true", default=False)
+
 parser.add_argument("--problem", type=str, default="sky")
 
 parser.add_argument("--nb_gpts", type=int, default=5)
 
-parser.add_argument("--nb_models_for_generation", type=int, default=1)
-
-parser.add_argument("--generation_mode", type=str, default="groupthink")
-
 parser.add_argument("--min_to_validate", type=int, default=4)
 
 parser.add_argument("--max_to_validate", type=int, default=4)
@@ -95,6 +93,16 @@ parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975)
 
 parser.add_argument("--dirty_debug", action="store_true", default=False)
 
+parser.add_argument("--sky_height", type=int, default=6)
+
+parser.add_argument("--sky_width", type=int, default=8)
+
+parser.add_argument("--sky_nb_birds", type=int, default=3)
+
+parser.add_argument("--sky_nb_iterations", type=int, default=2)
+
+parser.add_argument("--sky_speed", type=int, default=3)
+
 ######################################################################
 
 args = parser.parse_args()
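
The defaults of the new sky_* arguments reproduce the previously hard-coded Sky configuration, except that the default speed goes from 2 to 3. As a hypothetical invocation (flag names taken from the parser above), the old behavior can be recovered explicitly with:

    python main.py --problem sky --sky_height 6 --sky_width 8 --sky_nb_birds 3 --sky_nb_iterations 2 --sky_speed 2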
@@ -222,9 +230,15 @@ assert args.nb_train_samples % args.batch_size == 0
 assert args.nb_test_samples % args.batch_size == 0
 
 if args.problem == "sky":
-    problem = sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2, speed=2)
+    problem = sky.Sky(
+        height=args.sky_height,
+        width=args.sky_width,
+        nb_birds=args.sky_nb_birds,
+        nb_iterations=args.sky_nb_iterations,
+        speed=args.sky_speed,
+    )
 elif args.problem == "wireworld":
-    problem = wireworld.Wireworld(height=10, width=15, nb_iterations=4)
+    problem = wireworld.Wireworld(height=8, width=10, nb_iterations=2, speed=5)
 else:
     raise ValueError
 
@@ -365,94 +379,83 @@ def run_tests(model, quizz_machine, deterministic_synthesis):
 ######################################################################
 
 
+def valid_c_quizzes(recorded, criteria):
+    result = [q[criteria(c)] for q, c in recorded]
+    return torch.cat(result, dim=0) if len(result) > 0 else torch.tensor([])
+
+
+######################################################################
+
+
 def create_c_quizzes(
     models,
     quizz_machine,
     nb_for_train=1000,
     nb_for_test=100,
-    min_ave_seq_logproba=None,
 ):
-    # We will store the generated quizzes for each number of
-    # correct prediction
-    recorded = dict([(n, []) for n in range(len(models) + 1)])
+    recorded = []
 
-    model_indexes = []
     sum_logits, sum_nb_c_quizzes = 0, 0
 
-    def nb_generated():
-        return sum([sum([x.size(0) for x in recorded[n]]) for n in recorded.keys()])
+    nb_to_create = nb_for_train + nb_for_test
 
-    def nb_validated():
-        return sum(
-            [
-                sum([x.size(0) for x in recorded[n]])
-                for n in range(args.min_to_validate, args.max_to_validate + 1)
-            ]
-        )
+    # ------------------------------------------------------------
 
-    nb_to_create = nb_for_train + nb_for_test
+    standard_validity = lambda nb_correct: torch.logical_and(
+        nb_correct >= args.min_to_validate, nb_correct <= args.max_to_validate
+    )
 
-    while nb_validated() < nb_to_create:
-        (
-            new_c_quizzes,
-            nb_correct,
-            ave_seq_logproba,
-        ) = quizz_machine.gang_create_c_quizzes(
-            nb=nb_to_create,
-            nb_models_for_generation=args.nb_models_for_generation,
-            models=models,
-            mode=args.generation_mode,
-            min_ave_seq_logproba=min_ave_seq_logproba,
-            n_epoch=n_epoch,
-            result_dir=args.result_dir,
+    while valid_c_quizzes(recorded, standard_validity).size(0) < nb_to_create:
+        model_for_generation = models[torch.randint(len(models), (1,))]
+
+        c_quizzes, ave_seq_logproba = quizz_machine.generate_quizzes(
+            nb_to_create,
+            model_for_generation=model_for_generation,
+            reverse_cleanup=args.reverse_cleanup,
         )
 
-        sum_logits += new_c_quizzes.size(0) * ave_seq_logproba
-        sum_nb_c_quizzes += new_c_quizzes.size(0)
+        sum_logits += c_quizzes.size(0) * ave_seq_logproba
+        sum_nb_c_quizzes += c_quizzes.size(0)
+
+        nb_correct = quizz_machine.compute_correctness(
+            c_quizzes, models, both_direction=True
+        )
 
         if args.dirty_debug:
             nb_correct = torch.randint(
-                len(models) + 1, nb_correct.size(), device=new_c_quizzes.device
+                len(models) + 1, nb_correct.size(), device=c_quizzes.device
             )
 
-        for n in range(nb_correct.max() + 1):
-            recorded[n].append(new_c_quizzes[nb_correct == n].clone())
+        recorded.append((c_quizzes, nb_correct))
+
+        nv = F.one_hot(nb_correct, num_classes=len(models) + 1).sum(0)
+        nv = " ".join([str(x.item()) for x in nv])
+
+        nb_validated = valid_c_quizzes(recorded, standard_validity).size(0)
 
         log_string(
-            f"keep c_quizzes {nb_validated()*100/nb_generated():.02f}% kept total {nb_validated()} / {nb_to_create}"
+            f"keep c_quizzes kept {nv} nb_accumulated {nb_validated} / {nb_to_create}"
         )
 
-    # concatenate and shuffle
-    for n in recorded.keys():
-        if len(recorded[n]) > 0:
-            q = torch.cat(recorded[n], dim=0)
-            q = q[torch.randperm(q.size(0), device=q.device)]
-            recorded[n] = q
-        else:
-            del recorded[n]
-
-    new_c_quizzes = torch.cat(
-        [recorded[n] for n in range(args.min_to_validate, args.max_to_validate + 1)],
-        dim=0,
-    )
+    # store the new c_quizzes that have been validated
 
-    new_c_quizzes = new_c_quizzes[
-        torch.randperm(new_c_quizzes.size(0), device=new_c_quizzes.device)[
-            : nb_for_train + nb_for_test
-        ]
-    ]
+    new_c_quizzes = valid_c_quizzes(recorded, standard_validity)
 
     quizz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
     quizz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
 
-    for n in recorded.keys():
+    # save a bunch of images to investigate what quizzes with a
+    # certain nb of correct predictions look like
+
+    for n in range(len(models) + 1):
         s = (
             "_validated"
             if n >= args.min_to_validate and n <= args.max_to_validate
             else ""
         )
+
         quizz_machine.problem.save_quizzes(
-            recorded[n][:72],
+            valid_c_quizzes(recorded, criteria=lambda nb_correct: nb_correct == n)[:72],
             args.result_dir,
             f"culture_c_quiz_{n_epoch:04d}_N{n}{s}",
         )
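
For reference, the two tensor idioms this rewrite leans on can be checked in isolation. The sketch below (toy tensors, not from the repository) shows valid_c_quizzes keeping the rows of each recorded batch whose per-quiz count of correct models satisfies the boolean criterion, and the F.one_hot(...).sum(0) histogram that gets logged:

    import torch
    import torch.nn.functional as F

    min_to_validate, max_to_validate, nb_models = 4, 4, 5

    # one recorded entry: a toy batch of six 2-token "quizzes" and, for
    # each quiz, the number of models that predicted it correctly
    recorded = [(torch.arange(12).reshape(6, 2),
                 torch.tensor([0, 4, 5, 4, 2, 4]))]

    standard_validity = lambda nb_correct: torch.logical_and(
        nb_correct >= min_to_validate, nb_correct <= max_to_validate
    )

    def valid_c_quizzes(recorded, criteria):
        # boolean-mask each batch with criteria(nb_correct), then concatenate
        result = [q[criteria(c)] for q, c in recorded]
        return torch.cat(result, dim=0) if len(result) > 0 else torch.tensor([])

    print(valid_c_quizzes(recorded, standard_validity).size(0))  # -> 3

    # per-count histogram, as in the "keep c_quizzes" log line
    nv = F.one_hot(recorded[0][1], num_classes=nb_models + 1).sum(0)
    print(" ".join(str(x.item()) for x in nv))  # -> "1 0 1 0 3 1"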
@@ -487,57 +490,43 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 
 ######################################################################
 
-min_ave_seq_logproba = None
-
 for n_epoch in range(args.nb_epochs):
     log_string(f"--- epoch {n_epoch} ----------------------------------------")
 
-    a = [(model.id, float(model.main_test_accuracy)) for model in models]
-    a.sort(key=lambda p: p[0])
-    s = " ".join([f"{p[1]*100:.02f}%" for p in a])
-    log_string(f"current accuracies {s}")
-
-    # select the model with lowest accuracy
-    models.sort(key=lambda model: model.main_test_accuracy)
-    model = models[0]
+    weakest_model = min(models, key=lambda m: float(m.main_test_accuracy))
 
     log_string(
-        f"training model {model.id} main_test_accuracy {model.main_test_accuracy}"
+        f"training model {weakest_model.id} main_test_accuracy {weakest_model.main_test_accuracy}"
     )
 
     # improve it
-    one_epoch(model, quizz_machine)
-
-    quizz_machine.renew_w_quizzes(args.nb_train_samples // args.nb_gpts)
+    one_epoch(weakest_model, quizz_machine)
 
     log_string(
         f"train_set_composition w_quizzes {quizz_machine.nb_batch_w_quizzes} c_quizzes {quizz_machine.nb_batch_c_quizzes}"
     )
 
     # test it
-    run_tests(model, quizz_machine, deterministic_synthesis=False)
+    run_tests(weakest_model, quizz_machine, deterministic_synthesis=False)
 
     log_string(
         f"test_set_composition w_quizzes {quizz_machine.nb_batch_w_quizzes} c_quizzes {quizz_machine.nb_batch_c_quizzes}"
     )
 
+    cta = " ".join([f"{float(m.main_test_accuracy):.02f}" for m in models])
+    log_string(f"current_test_accuracies {cta}")
+
+    # replace a fraction of the w_quizzes with fresh ones
+    quizz_machine.renew_w_quizzes(args.nb_train_samples // args.nb_gpts)
+
     if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
-        ave_seq_logproba = create_c_quizzes(
+        create_c_quizzes(
             models,
             quizz_machine,
             nb_for_train=nb_new_c_quizzes_for_train,
             nb_for_test=nb_new_c_quizzes_for_test,
-            min_ave_seq_logproba=min_ave_seq_logproba,
         )
 
-        # We keep the first average logits as a reference
-        # if min_ave_seq_logproba is None:
-        # min_ave_seq_logproba = ave_seq_logproba
-        # else:
-        # log_string(
-        # f"min_ave_seq_logproba {min_ave_seq_logproba} ave_seq_logproba {ave_seq_logproba}"
-        # )
-
         # We update everyone
         for model in models:
             run_tests(model, quizz_machine, deterministic_synthesis=False)
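
To summarize the revised outer loop: each epoch trains and retests only the weakest model, and c-quiz creation is gated on every model clearing the accuracy threshold, after which all models are retested. A minimal sketch of the selection logic, with a hypothetical Model stub and made-up accuracies:

    class Model:
        def __init__(self, id, main_test_accuracy):
            self.id = id
            self.main_test_accuracy = main_test_accuracy

    models = [Model(0, 0.99), Model(1, 0.95), Model(2, 0.98)]
    accuracy_to_make_c_quizzes = 0.975

    # train only the model that currently tests worst
    weakest_model = min(models, key=lambda m: float(m.main_test_accuracy))
    print(f"training model {weakest_model.id}")  # -> training model 1

    # c-quizzes are created only once *all* models clear the threshold
    if min(m.main_test_accuracy for m in models) >= accuracy_to_make_c_quizzes:
        print("create c_quizzes, then retest every model")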