Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 24 Aug 2024 14:16:14 +0000 (16:16 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 24 Aug 2024 14:16:14 +0000 (16:16 +0200)
main.py

diff --git a/main.py b/main.py
index d3d237e..e4bb494 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -970,6 +970,7 @@ def test_ae(local_device=main_device):
             targets, input = degrade_input(
                 input, mask_generate, rho / nb_iterations, (rho + 1) / nb_iterations
             )
+
             input_with_mask = NTC_channel_cat(input, mask_generate, rho)
             output = model(input_with_mask)
             loss = NTC_masked_cross_entropy(output, targets, mask_loss)
@@ -1039,7 +1040,7 @@ def test_ae(local_device=main_device):
                     f"test_accuracy {n_epoch} model AE setup {ns} {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
                 )
 
-                filename = f"prediction_ae_{n_epoch:04d}_structure_{ns}.png"
+                filename = f"prediction_ae_{n_epoch:04d}_{ns}.png"
 
                 quiz_machine.problem.save_quizzes_as_image(
                     args.result_dir,
@@ -1049,7 +1050,7 @@ def test_ae(local_device=main_device):
                     correct_parts=correct_parts,
                 )
 
-            log_string(f"wrote {filename}")
+                log_string(f"wrote {filename}")
 
 
 if args.test == "ae":