Update.
[culture.git] / quizz_machine.py
index 90f288e..62ae8ce 100755 (executable)
@@ -122,12 +122,13 @@ class QuizzMachine:
         forward_to_backward = torch.cat(
             [
                 quizzes[:, 0:1],
-                quizzes[:, 2 + self.prompt_len :],
-                quizzes[:, 1 + self.prompt_len : 2 + self.prompt_len],
+                quizzes[:, 2 + self.prompt_len : 2 + self.prompt_len + self.answer_len],
+                quizzes[:, 1 + self.prompt_len : 1 + self.prompt_len + 1],
                 quizzes[:, 1 : 1 + self.prompt_len],
             ],
             dim=1,
         )
+
         forward_to_backward[:, 0] = self.token_backward
         forward_to_backward[:, 1 + self.answer_len] = self.token_backward
 
@@ -202,6 +203,7 @@ class QuizzMachine:
         problem,
         nb_train_samples,
         nb_test_samples,
+        back_accuracy,
         batch_size,
         result_dir,
         logger,
@@ -215,6 +217,7 @@ class QuizzMachine:
         self.nb_token_values = v + 2
 
         self.problem = problem
+        self.back_accuracy = back_accuracy
         self.batch_size = batch_size
         self.device = device
         self.logger = logger
@@ -232,14 +235,14 @@ class QuizzMachine:
 
         if result_dir is not None:
             self.save_quizzes(
-                result_dir, "culture_w_quizzes", self.train_w_quizzes[:72]
+                result_dir,
+                "culture_w_quizzes",
+                self.train_w_quizzes[:72],
+                prediction=True,
             )
 
-            # toto = self.reverse_time(self.train_w_quizzes[:72])
-            # self.save_quizzes(result_dir, "toto", toto)
-            # exit(0)
-
     def save_quizzes(self, result_dir, filename_prefix, quizzes, prediction=False):
+        quizzes = quizzes.clone()
         forward = quizzes[quizzes[:, 0] == self.token_forward]
         ib = quizzes[:, 0] == self.token_backward
         backward = quizzes[ib]
@@ -308,7 +311,6 @@ class QuizzMachine:
         self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
     ):
         def compute_accuracy(input):
-            input = input[:nmax]
             ar_mask = self.make_ar_mask(input)
             result = input.clone() * (1 - ar_mask)
             seq_logproba = torch.empty(input.size(0), device=self.device)
@@ -325,18 +327,61 @@ class QuizzMachine:
                 device=self.device,
             )
 
-            nb_total = input.size(0)
-            nb_correct = (input == result).long().min(dim=1).values.sum()
+            if self.back_accuracy:
+                # If back_accuracy is True, we compute the accuracy on
+                # the backward quizzes not by counting how many times
+                # the real prompt A equals the reconstructed prompt A*,
+                # but by counting how many times the answer B* computed
+                # from A* equals the correct answer B. So we check
+                # A->B*=B for the forward quizzes, and B->A*->B*=B
+                # instead of B->A*=A for the backward ones.
+
+                n_forward = input[:, 0] == self.token_forward
+                nb_total = input[n_forward].size(0)
+                nb_correct = (
+                    (input[n_forward] == result[n_forward])
+                    .long()
+                    .min(dim=1)
+                    .values.sum()
+                    .item()
+                )
+
+                n_backward = input[:, 0] == self.token_backward
+                back_input = self.reverse_time(result[n_backward])
+
+                if back_input.size(0) > 0:
+                    back_input[:, 2 + self.prompt_len :] = input[
+                        n_backward, 1 : 1 + self.answer_len
+                    ]
+                    back_nb_total, back_nb_correct = compute_accuracy(back_input)
+
+                    self.logger(
+                        f"accuracy {n_epoch=} {model.id=} {nb_correct} / {nb_total}"
+                    )
+                    self.logger(
+                        f"back_accuracy {n_epoch=} {model.id=} {back_nb_correct} / {back_nb_total}"
+                    )
+
+                    nb_total += back_nb_total
+                    nb_correct += back_nb_correct
+                else:
+                    self.logger(
+                        f"accuracy {n_epoch=} {model.id=} {nb_correct} / {nb_total}"
+                    )
+
+            else:
+                nb_total = input.size(0)
+                nb_correct = (input == result).long().min(dim=1).values.sum()
 
             return nb_total, nb_correct
 
-        train_nb_total, train_nb_correct = compute_accuracy(self.train_w_quizzes)
+        train_nb_total, train_nb_correct = compute_accuracy(self.train_w_quizzes[:nmax])
 
         self.logger(
             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
         )
 
-        test_nb_total, test_nb_correct = compute_accuracy(self.test_w_quizzes)
+        test_nb_total, test_nb_correct = compute_accuracy(self.test_w_quizzes[:nmax])
 
         self.logger(
             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
 
         self.logger(
             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"