Update.
[culture.git] / quizz_machine.py
index c5da586..62ae8ce 100755 (executable)
@@ -346,10 +346,6 @@ class QuizzMachine:
                     .item()
                 )
 
-                self.logger(
-                    f"back_accuracy {n_epoch=} {model.id=} {nb_correct=} {nb_total=}"
-                )
-
                 n_backward = input[:, 0] == self.token_backward
                 back_input = self.reverse_time(result[n_backward])
 
@@ -358,18 +354,25 @@ class QuizzMachine:
                         n_backward, 1 : 1 + self.answer_len
                     ]
                     back_nb_total, back_nb_correct = compute_accuracy(back_input)
+
+                    self.logger(
+                        f"accuracy {n_epoch=} {model.id=} {nb_correct} / {nb_total}"
+                    )
                     self.logger(
-                        f"back_accuracy {n_epoch=} {model.id=} {back_nb_correct=} {back_nb_total=}"
+                        f"back_accuracy {n_epoch=} {model.id=} {back_nb_correct} / {back_nb_total}"
                     )
+
                     nb_total += back_nb_total
                     nb_correct += back_nb_correct
+                else:
+                    self.logger(
+                        f"accuracy {n_epoch=} {model.id=} {nb_correct} / {nb_total}"
+                    )
 
             else:
                 nb_total = input.size(0)
                 nb_correct = (input == result).long().min(dim=1).values.sum()
 
-            exit(0)
-
             return nb_total, nb_correct
 
         train_nb_total, train_nb_correct = compute_accuracy(self.train_w_quizzes[:nmax])