From a483486c06ab64b11c561862783918cda5ddf46c Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 16 Sep 2024 22:56:21 +0200 Subject: [PATCH] Update. --- main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/main.py b/main.py index 70ca672..d21c54b 100755 --- a/main.py +++ b/main.py @@ -814,10 +814,12 @@ def one_train_test_epoch(model, n_epoch, c_quizzes, local_device=main_device): correct_parts = (2 * correct - 1)[:, None] * mask.reshape(mask.size(0), 4, -1)[ :, :, 1 ] + predicted_parts = correct_parts.abs() quiz_machine.problem.save_quizzes_as_image( args.result_dir, f"culture_prediction_{n_epoch}_{model.id}.png", quizzes=result[:128], + predicted_parts=predicted_parts[:128], correct_parts=correct_parts[:128], ) -- 2.39.5