Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 5 Jul 2023 07:25:40 +0000 (09:25 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 5 Jul 2023 07:25:40 +0000 (09:25 +0200)
main.py

diff --git a/main.py b/main.py
index 15e6d99..32447bf 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -1119,33 +1119,40 @@ class TaskExpr(Task):
                 #######################################################################
                 # Comput predicted vs. true variable values
 
+                nb_delta = torch.zeros(5, dtype=torch.int64)
+                nb_missed = 0
+
                 values_input = expr.extract_results([self.seq2str(s) for s in input])
-                max_input = max([max(x.values()) for x in values_input])
                 values_result = expr.extract_results([self.seq2str(s) for s in result])
-                max_result = max(
-                    [-1 if len(x) == 0 else max(x.values()) for x in values_result]
-                )
-
-                nb_missing = torch.zeros(max_input + 1)
-                nb_predicted = torch.zeros(max_input + 1, max_result + 1)
 
                 for i, r in zip(values_input, values_result):
                     for n, vi in i.items():
                         vr = r.get(n)
                         if vr is None or vr < 0:
-                            nb_missing[vi] += 1
+                            nb_missed += 1
                         else:
-                            nb_predicted[vi, vr] += 1
+                            d = abs(vr-vi)
+                            if d >= nb_delta.size(0):
+                                nb_missed += 1
+                            else:
+                                nb_delta[d] += 1
+
                 ######################################################################
 
-                return nb_total, nb_correct
+                return nb_total, nb_correct, nb_delta, nb_missed
 
-            test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])
+            test_nb_total, test_nb_correct, test_nb_delta, test_nb_missed = compute_nb_correct(self.test_input[:1000])
 
             log_string(
                 f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
             )
 
+            nb_total = test_nb_delta.sum() + test_nb_missed
+            for d in range(test_nb_delta.size(0)):
+                log_string(f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%")
+            log_string(f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%")
+
+
             ##############################################################
             # Log a few generated sequences
             input = self.test_input[:10]