From: François Fleuret Date: Sat, 24 Aug 2024 14:16:14 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=8f17719e388a3800b8fd894cb407394320415a32;p=culture.git Update. --- diff --git a/main.py b/main.py index d3d237e..e4bb494 100755 --- a/main.py +++ b/main.py @@ -970,6 +970,7 @@ def test_ae(local_device=main_device): targets, input = degrade_input( input, mask_generate, rho / nb_iterations, (rho + 1) / nb_iterations ) + input_with_mask = NTC_channel_cat(input, mask_generate, rho) output = model(input_with_mask) loss = NTC_masked_cross_entropy(output, targets, mask_loss) @@ -1039,7 +1040,7 @@ def test_ae(local_device=main_device): f"test_accuracy {n_epoch} model AE setup {ns} {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)" ) - filename = f"prediction_ae_{n_epoch:04d}_structure_{ns}.png" + filename = f"prediction_ae_{n_epoch:04d}_{ns}.png" quiz_machine.problem.save_quizzes_as_image( args.result_dir, @@ -1049,7 +1050,7 @@ def test_ae(local_device=main_device): correct_parts=correct_parts, ) - log_string(f"wrote {filename}") + log_string(f"wrote {filename}") if args.test == "ae":