Add a reverse_time() helper and a reverse_cleanup option for c_quizz generation; disable the noise-injection experiment.
[culture.git] / quizz_machine.py
index 84bb558..6f7492d 100755 (executable)
@@ -312,9 +312,7 @@ class QuizzMachine:
         else:
             self.test_c_quizzes.append(new_c_quizzes)
 
-    def comput_correctness(self, c_quizzes, models_for_validation):
-        # Create the reverse quizzes
-
+    def reverse_time(self, c_quizzes):
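+        # Swap the two halves of the quiz and flip the direction token.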
         token_forward, token_backward = self.problem.direction_tokens()
 
         l = (c_quizzes.size(1) - 1) // 2
@@ -322,9 +320,11 @@ class QuizzMachine:
         direction = self.problem.token_forward * (
             direction == self.problem.token_backward
         ) + self.problem.token_backward * (direction == self.problem.token_forward)
-        reverse_c_quizzes = torch.cat(
-            [c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1
-        )
+
+        return torch.cat([c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1)
+
+    def comput_correctness(self, c_quizzes, models_for_validation):
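+        # A quiz is counted as correct for a model only if the model solves it in both time directions.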
+        reversed_c_quizzes = self.reverse_time(c_quizzes)
 
         ar_mask = self.make_ar_mask(c_quizzes)
         seq_logproba = torch.empty(ar_mask.size(0), device=self.device)
@@ -350,12 +350,12 @@ class QuizzMachine:
 
             correct = (c_quizzes == result).long().min(dim=-1).values
 
-            reverse_result = reverse_c_quizzes.clone()
+            reversed_result = reversed_c_quizzes.clone()
 
             masked_inplace_autoregression(
                 model=model,
                 batch_size=self.batch_size,
-                input=reverse_result,
+                input=reversed_result,
                 ar_mask=ar_mask,
                 seq_logproba=seq_logproba,
                 temperature=1.0,
@@ -364,17 +364,19 @@ class QuizzMachine:
                 device=self.device,
             )
 
-            reverse_correct = (
-                (reverse_c_quizzes == reverse_result).long().min(dim=-1).values
+            reversed_correct = (
+                (reversed_c_quizzes == reversed_result).long().min(dim=-1).values
             )
 
-            nb_correct += correct * reverse_correct
+            nb_correct += correct * reversed_correct
 
         return nb_correct
 
     ###############################################################
 
-    def generate_quizzes(self, nb, model_for_generation, min_ave_seq_logproba):
+    def generate_quizzes(
+        self, nb, model_for_generation, min_ave_seq_logproba, reverse_cleanup=False
+    ):
         c_quizzes = torch.empty(
             nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
         )
@@ -384,11 +386,17 @@ class QuizzMachine:
         ar_mask_solve = 1 - ar_mask_prompt
         seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device)
 
-        warnings.warn("noise injection", RuntimeWarning)
-        temperature = 1
-        noise_std = torch.rand(1).item()
-        self.logger(f"{noise_std=}")
-        mygpt.set_noise_injection(model_for_generation, noise_std)
+        if reverse_cleanup:
+            warnings.warn("very high temperature with reversed cleanup", RuntimeWarning)
+            temperature = 10.0
+        else:
+            temperature = 1.0
+
+        # warnings.warn("noise injection", RuntimeWarning)
+        # noise_std = torch.rand(1).item()
+        # self.logger(f"{noise_std=}")
+
+        # mygpt.set_noise_injection(model_for_generation, noise_std)
 
         masked_inplace_autoregression(
             model=model_for_generation,
@@ -402,6 +410,8 @@ class QuizzMachine:
             device=self.device,
         )
 
+        # mygpt.set_noise_injection(model_for_generation, 0.0)
+
         ave_seq_logproba = seq_logproba.mean()
 
         masked_inplace_autoregression(
@@ -416,7 +426,19 @@ class QuizzMachine:
             device=self.device,
         )
 
-        mygpt.set_noise_injection(model_for_generation, 0.0)
+        if reverse_cleanup:
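+            # Clean up the high-temperature sample by re-solving the time-reversed quiz deterministically.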
+            c_quizzes = self.reverse_time(c_quizzes)
+            masked_inplace_autoregression(
+                model=model_for_generation,
+                batch_size=self.batch_size,
+                input=c_quizzes,
+                ar_mask=ar_mask_solve,
+                seq_logproba=seq_logproba,
+                temperature=temperature,
+                deterministic_synthesis=True,
+                # progress_bar_desc="sampling c_quizzes",
+                device=self.device,
+            )
 
         return c_quizzes, seq_logproba.mean()
 
@@ -428,11 +450,15 @@ class QuizzMachine:
         model_for_generation,
         models_for_validation,
         min_ave_seq_logproba,
+        reverse_cleanup,
         n_epoch,
         result_dir,
     ):
         c_quizzes, ave_seq_logproba = self.generate_quizzes(
-            nb, model_for_generation, min_ave_seq_logproba
+            nb,
+            model_for_generation=model_for_generation,
+            min_ave_seq_logproba=min_ave_seq_logproba,
+            reverse_cleanup=reverse_cleanup,
         )
 
         nb_correct = self.comput_correctness(c_quizzes, models_for_validation)
@@ -448,16 +474,18 @@ class QuizzMachine:
         models,
         mode,
         min_ave_seq_logproba,
+        reverse_cleanup,
         n_epoch,
         result_dir,
     ):
         model_for_generation = Gang(models, nb_models_for_generation, mode)
         models_for_validation = models
         return self.create_c_quizzes(
-            nb,
-            model_for_generation,
-            models_for_validation,
-            min_ave_seq_logproba,
-            n_epoch,
-            result_dir,
+            nb=nb,
+            model_for_generation=model_for_generation,
+            models_for_validation=models_for_validation,
+            min_ave_seq_logproba=min_ave_seq_logproba,
+            reverse_cleanup=reverse_cleanup,
+            n_epoch=n_epoch,
+            result_dir=result_dir,
         )
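
For reference, a minimal sketch of how the new reverse_cleanup option could be exercised from calling code. The quizz_machine and model objects, the quiz count, and the min_ave_seq_logproba value are placeholders, not part of this commit; only the method names and signatures come from the diff above.

    # Hypothetical caller (objects and argument values assumed): generate candidate
    # c_quizzes with the reversed cleanup pass, then count, for each quiz, how many
    # validation models solve it in both time directions.
    c_quizzes, ave_seq_logproba = quizz_machine.generate_quizzes(
        nb=128,
        model_for_generation=model_for_generation,
        min_ave_seq_logproba=None,
        reverse_cleanup=True,
    )

    nb_correct = quizz_machine.comput_correctness(c_quizzes, models_for_validation)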