Update.
[culture.git] / quizz_machine.py
index 90f288e..6e57fb4 100755 (executable)
@@ -202,6 +202,7 @@ class QuizzMachine:
         problem,
         nb_train_samples,
         nb_test_samples,
+        back_accuracy,
         batch_size,
         result_dir,
         logger,
@@ -215,6 +216,7 @@ class QuizzMachine:
         self.nb_token_values = v + 2
 
         self.problem = problem
+        self.back_accuracy = back_accuracy
         self.batch_size = batch_size
         self.device = device
         self.logger = logger
@@ -308,7 +310,6 @@ class QuizzMachine:
         self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
     ):
         def compute_accuracy(input):
-            input = input[:nmax]
             ar_mask = self.make_ar_mask(input)
             result = input.clone() * (1 - ar_mask)
             seq_logproba = torch.empty(input.size(0), device=self.device)
@@ -325,18 +326,38 @@ class QuizzMachine:
                 device=self.device,
             )
 
-            nb_total = input.size(0)
-            nb_correct = (input == result).long().min(dim=1).values.sum()
+            if self.back_accuracy:
+                n_forward = input[:, 0] == self.token_forward
+                nb_total = input[n_forward].size(0)
+                nb_correct = (
+                    (input[n_forward] == result[n_forward])
+                    .long()
+                    .min(dim=1)
+                    .values.sum()
+                )
+
+                n_backward = input[:, 0] == self.token_backward
+                back_input = self.reverse_time(result[n_backward])
+                if back_input.size(0) > 0:
+                    back_input[:, 2 + self.prompt_len :] = input[
+                        n_backward, 2 + self.prompt_len :
+                    ]
+                    back_nb_total, back_nb_correct = compute_accuracy(back_input)
+                    nb_total += back_nb_total
+                    nb_correct += back_nb_correct
+            else:
+                nb_total = input.size(0)
+                nb_correct = (input == result).long().min(dim=1).values.sum()
 
             return nb_total, nb_correct
 
-        train_nb_total, train_nb_correct = compute_accuracy(self.train_w_quizzes)
+        train_nb_total, train_nb_correct = compute_accuracy(self.train_w_quizzes[:nmax])
 
         self.logger(
             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
         )
 
-        test_nb_total, test_nb_correct = compute_accuracy(self.test_w_quizzes)
+        test_nb_total, test_nb_correct = compute_accuracy(self.test_w_quizzes[:nmax])
 
         self.logger(
             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"