From ca5b98d1517b8ce2367887bbad2205f27d55e0b3 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 5 Jul 2023 09:26:40 +0200 Subject: [PATCH] Update. --- main.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index 32447bf..9dee679 100755 --- a/main.py +++ b/main.py @@ -1131,7 +1131,7 @@ class TaskExpr(Task): if vr is None or vr < 0: nb_missed += 1 else: - d = abs(vr-vi) + d = abs(vr - vi) if d >= nb_delta.size(0): nb_missed += 1 else: @@ -1141,7 +1141,12 @@ class TaskExpr(Task): return nb_total, nb_correct, nb_delta, nb_missed - test_nb_total, test_nb_correct, test_nb_delta, test_nb_missed = compute_nb_correct(self.test_input[:1000]) + ( + test_nb_total, + test_nb_correct, + test_nb_delta, + test_nb_missed, + ) = compute_nb_correct(self.test_input[:1000]) log_string( f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" @@ -1149,9 +1154,12 @@ class TaskExpr(Task): nb_total = test_nb_delta.sum() + test_nb_missed for d in range(test_nb_delta.size(0)): - log_string(f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%") - log_string(f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%") - + log_string( + f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%" + ) + log_string( + f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%" + ) ############################################################## # Log a few generated sequences -- 2.20.1