Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 16 Sep 2024 20:56:21 +0000 (22:56 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 16 Sep 2024 20:56:21 +0000 (22:56 +0200)
main.py

diff --git a/main.py b/main.py
index 70ca672..d21c54b 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -814,10 +814,12 @@ def one_train_test_epoch(model, n_epoch, c_quizzes, local_device=main_device):
     correct_parts = (2 * correct - 1)[:, None] * mask.reshape(mask.size(0), 4, -1)[
         :, :, 1
     ]
+    predicted_parts = correct_parts.abs()
     quiz_machine.problem.save_quizzes_as_image(
         args.result_dir,
         f"culture_prediction_{n_epoch}_{model.id}.png",
         quizzes=result[:128],
+        predicted_parts=predicted_parts[:128],
         correct_parts=correct_parts[:128],
     )