Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 4 Aug 2024 05:02:16 +0000 (07:02 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 4 Aug 2024 05:02:16 +0000 (07:02 +0200)
main.py

diff --git a/main.py b/main.py
index 9a8bd43..8f3568f 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -477,110 +477,110 @@ c_quizzes_procedure = [
 ######################################################################
 
 
-def save_additional_results(models, science_w_quizzes):
+def save_additional_results(model, models, science_w_quizzes):
     # Save generated quizzes with the successive steps
 
-    for model in models:
-        recorder = []
+    recorder = []
 
-        c_quizzes = quiz_machine.generate_c_quizzes(
-            64,
-            model_for_generation=model,
-            procedure=c_quizzes_procedure,
-            recorder=recorder,
-        )
+    c_quizzes = quiz_machine.generate_c_quizzes(
+        64,
+        model_for_generation=model,
+        procedure=c_quizzes_procedure,
+        recorder=recorder,
+    )
 
-        ##
+    ##
 
-        probas = 0
+    probas = 0
 
-        for a in range(args.nb_averaging_rounds):
-            # This is nb_quizzes x nb_models
+    for a in range(args.nb_averaging_rounds):
+        # This is nb_quizzes x nb_models
 
-            seq_logproba = quiz_machine.models_logprobas(
-                models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
-            ) + quiz_machine.models_logprobas(
-                models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
-            )
+        seq_logproba = quiz_machine.models_logprobas(
+            models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
+        ) + quiz_machine.models_logprobas(
+            models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
+        )
 
-            probas += seq_logproba.exp()
+        probas += seq_logproba.exp()
 
-        probas /= args.nb_averaging_rounds
+    probas /= args.nb_averaging_rounds
 
-        comments = []
+    comments = []
 
-        for l in seq_logproba:
-            comments.append("proba " + " ".join([f"{x.exp().item():.02f}" for x in l]))
+    for l in seq_logproba:
+        comments.append("proba " + " ".join([f"{x.exp().item():.02f}" for x in l]))
 
-        ##
+    ##
 
-        c_quizzes = torch.cat([c[:, None, :] for c, _, in recorder], dim=1)
-        predicted_parts = torch.cat([t[:, None, :] for _, t in recorder], dim=1)
-        nb_steps = c_quizzes.size(1)
-        c_quizzes = c_quizzes.reshape(-1, c_quizzes.size(-1))
-        predicted_parts = predicted_parts.reshape(-1, predicted_parts.size(-1))
+    c_quizzes = torch.cat([c[:, None, :] for c, _, in recorder], dim=1)
+    predicted_parts = torch.cat([t[:, None, :] for _, t in recorder], dim=1)
+    nb_steps = c_quizzes.size(1)
+    c_quizzes = c_quizzes.reshape(-1, c_quizzes.size(-1))
+    predicted_parts = predicted_parts.reshape(-1, predicted_parts.size(-1))
 
-        # We have comments only for the final quiz, not the successive
-        # steps, so we have to add nb_steps-1 empty comments
+    # We have comments only for the final quiz, not the successive
+    # steps, so we have to add nb_steps-1 empty comments
 
-        steps_comments = []
-        for c in comments:
-            steps_comments += [""] * (nb_steps - 1) + [c]
+    steps_comments = []
+    for c in comments:
+        steps_comments += [""] * (nb_steps - 1) + [c]
 
-        filename = f"non_validated_{n_epoch:04d}_{model.id:02d}.png"
-        quiz_machine.problem.save_quizzes_as_image(
-            args.result_dir,
-            filename,
-            quizzes=c_quizzes,
-            predicted_parts=predicted_parts,
-            comments=steps_comments,
-            nrow=nb_steps * 2,  # two quiz per row
-        )
-        log_string(f"wrote {filename}")
+    filename = f"non_validated_{n_epoch:04d}_{model.id:02d}.png"
+
+    quiz_machine.problem.save_quizzes_as_image(
+        args.result_dir,
+        filename,
+        quizzes=c_quizzes,
+        predicted_parts=predicted_parts,
+        comments=steps_comments,
+        nrow=nb_steps * 2,  # two quiz per row
+    )
+
+    log_string(f"wrote {filename}")
 
     ######################################################################
 
     if science_w_quizzes is not None:
-        for model in models:
-            struct = ("A", "f_A", "B", "f_B")
-            mask = (0, 0, 0, 1)
-            result, correct = quiz_machine.predict(
-                model=model,
-                quizzes=science_w_quizzes.to(main_device),
-                struct=struct,
-                mask=mask,
-            )
+        struct = ("A", "f_A", "B", "f_B")
+        mask = (0, 0, 0, 1)
+        result, correct = quiz_machine.predict(
+            model=model,
+            quizzes=science_w_quizzes.to(main_device),
+            struct=struct,
+            mask=mask,
+        )
 
-            predicted_parts = torch.tensor(mask, device=correct.device)[None, :].expand(
-                correct.size(0), -1
-            )
-            correct = (2 * correct - 1) * (predicted_parts.sum(dim=-1) == 1).long()
+        predicted_parts = torch.tensor(mask, device=correct.device)[None, :].expand(
+            correct.size(0), -1
+        )
+        correct = (2 * correct - 1) * (predicted_parts.sum(dim=-1) == 1).long()
 
-            nb_correct = (correct == 1).long().sum()
-            nb_total = (correct != 0).long().sum()
+        nb_correct = (correct == 1).long().sum()
+        nb_total = (correct != 0).long().sum()
 
-            log_string(
-                f"science_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total}"
-            )
+        log_string(
+            f"science_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total}"
+        )
 
-            i = correct == 1
-            j = correct != 1
+        i = correct == 1
+        j = correct != 1
 
-            result = torch.cat([result[i], result[j]], dim=0)
-            correct = torch.cat([correct[i], correct[j]], dim=0)
-            correct_parts = predicted_parts * correct[:, None]
+        result = torch.cat([result[i], result[j]], dim=0)
+        correct = torch.cat([correct[i], correct[j]], dim=0)
+        correct_parts = predicted_parts * correct[:, None]
 
-            result = result[:128]
-            predicted_parts = predicted_parts[:128]
-            correct_parts = correct_parts[:128]
+        result = result[:128]
+        predicted_parts = predicted_parts[:128]
+        correct_parts = correct_parts[:128]
 
-            quiz_machine.problem.save_quizzes_as_image(
-                args.result_dir,
-                f"culture_science_{n_epoch:04d}_{model.id:02d}.png",
-                quizzes=result,
-                predicted_parts=predicted_parts,
-                correct_parts=correct_parts,
-            )
+        quiz_machine.problem.save_quizzes_as_image(
+            args.result_dir,
+            f"culture_science_{n_epoch:04d}_{model.id:02d}.png",
+            quizzes=result,
+            predicted_parts=predicted_parts,
+            correct_parts=correct_parts,
+        )
 
 
 ######################################################################
@@ -1310,7 +1310,8 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         )
         log_string(f"wrote {filename}")
 
-    save_additional_results(weakest_models, science_w_quizzes)
+    for model in weakest_models:
+        save_additional_results(model, models, science_w_quizzes)
 
     ######################################################################