Update: hard-code the grids problem, drop the grids_science_tasks path, log per-model teaching counts, and fix the c-quiz train/test split.
author     François Fleuret <francois@fleuret.org>
           Tue, 13 Aug 2024 13:53:01 +0000 (15:53 +0200)
committer  François Fleuret <francois@fleuret.org>
           Tue, 13 Aug 2024 13:53:01 +0000 (15:53 +0200)
main.py

diff --git a/main.py b/main.py
index dda62af..5ef583c 100755
--- a/main.py
+++ b/main.py
@@ -126,13 +126,6 @@ parser.add_argument(
     help="A comma-separated subset of: " + grids_tasks + ".",
 )
 
-parser.add_argument(
-    "--grids_science_tasks",
-    type=str,
-    default=None,
-    help="A comma-separated subset of: " + grids_tasks + ", or None.",
-)
-
 ######################################################################
 
 parser.add_argument("--sky_height", type=int, default=6)
@@ -152,14 +145,6 @@ args = parser.parse_args()
 if args.result_dir is None:
     args.result_dir = f"results_culture"
 
-assert not args.grids_science_tasks or (
-    len(
-        set(args.grids_world_tasks.split(","))
-        & set(args.grids_science_tasks.split(","))
-    )
-    == 0
-), "World and science tasks have to be disjoint"
-
 ######################################################################
 
 default_args = {
@@ -302,43 +287,12 @@ else:
 assert args.nb_train_samples % args.batch_size == 0
 assert args.nb_test_samples % args.batch_size == 0
 
-if args.problem == "sky":
-    problem = sky.Sky(
-        height=args.sky_height,
-        width=args.sky_width,
-        nb_birds=args.sky_nb_birds,
-        nb_iterations=args.sky_nb_iterations,
-        speed=args.sky_speed,
-        max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
-        chunk_size=100,
-        nb_threads=args.nb_threads,
-    )
-
-elif args.problem == "grids":
-    problem = grids.Grids(
-        max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
-        chunk_size=100,
-        nb_threads=args.nb_threads,
-        tasks=args.grids_world_tasks,
-    )
-
-    if args.grids_science_tasks is None:
-        science_w_quizzes = None
-    else:
-        science_problem = grids.Grids(
-            max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
-            chunk_size=100,
-            nb_threads=args.nb_threads,
-            tasks=args.grids_science_tasks,
-        )
-        science_w_quizzes = science_problem.generate_w_quizzes(100)
-
-        if not args.resume:
-            science_problem.save_some_examples(args.result_dir, "science_")
-
-
-else:
-    raise ValueError
+problem = grids.Grids(
+    max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
+    chunk_size=100,
+    nb_threads=args.nb_threads,
+    tasks=args.grids_world_tasks,
+)
 
 if not args.resume:
     problem.save_some_examples(args.result_dir)
@@ -516,7 +470,7 @@ c_quizzes_procedure = [
 ######################################################################
 
 
-def save_additional_results(model, models, science_w_quizzes):
+def save_additional_results(model, models):
     # Save generated quizzes with the successive steps
 
     recorder = []
@@ -576,49 +530,6 @@ def save_additional_results(model, models, science_w_quizzes):
 
     log_string(f"wrote {filename}")
 
-    ######################################################################
-
-    if science_w_quizzes is not None:
-        struct = ("A", "f_A", "B", "f_B")
-        mask = (0, 0, 0, 1)
-        result, correct, _ = quiz_machine.predict(
-            model=model,
-            quizzes=science_w_quizzes.to(main_device),
-            struct=struct,
-            mask=mask,
-        )
-
-        predicted_parts = torch.tensor(mask, device=correct.device)[None, :].expand(
-            correct.size(0), -1
-        )
-        correct = (2 * correct - 1) * (predicted_parts.sum(dim=-1) == 1).long()
-
-        nb_correct = (correct == 1).long().sum()
-        nb_total = (correct != 0).long().sum()
-
-        log_string(
-            f"science_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total}"
-        )
-
-        i = correct == 1
-        j = correct != 1
-
-        result = torch.cat([result[i], result[j]], dim=0)
-        correct = torch.cat([correct[i], correct[j]], dim=0)
-        correct_parts = predicted_parts * correct[:, None]
-
-        result = result[:128]
-        predicted_parts = predicted_parts[:128]
-        correct_parts = correct_parts[:128]
-
-        quiz_machine.problem.save_quizzes_as_image(
-            args.result_dir,
-            f"culture_science_{n_epoch:04d}_{model.id:02d}.png",
-            quizzes=result,
-            predicted_parts=predicted_parts,
-            correct_parts=correct_parts,
-        )
-
 
 ######################################################################
 
@@ -642,6 +553,8 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
     for model in models:
         model.recorded_c_quizzes = []
 
+    teaching_count = torch.zeros(len(models), len(models), dtype=torch.int64)
+
     while nb_validated < nb_to_validate:
         model_for_generation = models[torch.randint(len(models), (1,)).item()]
 
@@ -696,6 +609,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
                         proba_other_solutions[dont_get_it] = -1
                         i = proba_other_solutions.argmax()
                         model.recorded_c_quizzes.append(solved_c_quizzes[s, i])
+                        teaching_count[i, model.id] += 1
                         nb_validated += 1
 
         duration = time.perf_counter() - start_time
@@ -715,6 +629,10 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
             f"keep c_quizzes model {model_for_generation.id} validated {nb_validated} / {nb_to_generate_per_iteration} ({100*nb_validated/nb_to_generate_per_iteration:.02f}%) nb_accumulated {nb_validated} / {nb_to_validate} (finishes {e} -- {int((nb_validated * 3600)/duration)}/h)"
         )
 
+    for s in range(teaching_count.size(0)):
+        o = [x.item() for x in teaching_count[s]]
+        log_string(f"teacher model {s} to {o}")
+
     for model in models:
         new_bag = torch.cat([q[None, :] for q in model.recorded_c_quizzes], dim=0)
 
@@ -723,15 +641,29 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
             if n > 0:
                 model.train_c_quiz_bags.append(new_bag[:n])
             if n < new_bag.size(0):
-                model.test_c_quiz_bags.append(new_bag[:n])
+                model.test_c_quiz_bags.append(new_bag[n:])
 
-            vq = new_bag[:128]
+            c_quizzes = new_bag[:128]
 
-            seq_logprobas = quiz_machine.models_logprobas(
-                models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
-            ) + quiz_machine.models_logprobas(
-                models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
-            )
+            l = [
+                quiz_machine.models_logprobas(
+                    model,
+                    c_quizzes,
+                    ("A", "f_A", "B", "f_B"),
+                    (0, 0, 0, 1),
+                    (0, 0, 1, 0),
+                )
+                + quiz_machine.models_logprobas(
+                    model,
+                    c_quizzes,
+                    ("f_A", "A", "f_B", "B"),
+                    (0, 0, 0, 1),
+                    (0, 0, 1, 0),
+                )
+                for model in models
+            ]
+
+            seq_logprobas = torch.cat([x[:, None] for x in l], dim=1)
 
             probas = seq_logprobas.exp()
 
@@ -744,7 +676,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
 
             filename = f"culture_c_quiz_{n_epoch:04d}.png"
             quiz_machine.problem.save_quizzes_as_image(
-                args.result_dir, filename, vq, comments=comments
+                args.result_dir, filename, c_quizzes, comments=comments
             )
 
         log_string(
@@ -916,7 +848,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         log_string(f"wrote {filename}")
 
     for model in weakest_models:
-        save_additional_results(model, models, science_w_quizzes)
+        save_additional_results(model, models)
 
     ######################################################################
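A short illustration of the new teaching_count bookkeeping may help when reading the diff above. Going by the surrounding loop, entry [i, j] appears to count how many validated c-quizzes were added to model j's bags using the alternative solution proposed by model i, and the final loop logs one row per "teacher". The sketch below only mimics that reading with made-up indices; nb_models and the recorded (teacher, learner) pairs are placeholders, not values from the commit.

    import torch

    # Hypothetical setup: four models, as a stand-in for the `models` list.
    nb_models = 4
    teaching_count = torch.zeros(nb_models, nb_models, dtype=torch.int64)

    # Pretend a few c-quizzes were validated: each pair is
    # (index of the model whose solution was kept, id of the model that records it).
    for teacher, learner in [(1, 0), (2, 0), (1, 3), (0, 2)]:
        teaching_count[teacher, learner] += 1

    # Same reporting loop as in record_new_c_quizzes: one line per teacher,
    # with the per-recipient counts (print stands in for log_string).
    for s in range(teaching_count.size(0)):
        o = [x.item() for x in teaching_count[s]]
        print(f"teacher model {s} to {o}")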
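The other substantive change replaces the single models_logprobas call over all models with one call per model, then stacks the per-model results. Assuming each call returns a 1D tensor with one log-probability per c-quiz (an assumption, not something the diff states), the stacking reduces to the torch pattern below; the random tensors are stand-ins for the real outputs.

    import torch

    nb_quizzes, nb_models = 128, 4  # placeholder sizes

    # Stand-ins for the per-model outputs of quiz_machine.models_logprobas:
    # one log-probability per quiz, for each model.
    per_model = [torch.randn(nb_quizzes) for _ in range(nb_models)]

    # Same stacking as in the diff: add a singleton dimension and concatenate
    # along it, giving a (nb_quizzes, nb_models) matrix.
    seq_logprobas = torch.cat([x[:, None] for x in per_model], dim=1)
    assert seq_logprobas.shape == (nb_quizzes, nb_models)

    probas = seq_logprobas.exp()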