Update.
diff --git a/quizz_machine.py b/quizz_machine.py
index 5f19998..0d6d8f5 100755
--- a/quizz_machine.py
+++ b/quizz_machine.py
@@ -333,7 +333,7 @@ class QuizzMachine:
         )
 
     def compute_correctness(
-        self, c_quizzes, models_for_validation, both_directions=True
+        self, c_quizzes, models_for_validation, both_directions=False
     ):
         reversed_c_quizzes = self.reverse_time(c_quizzes)
 
@@ -390,77 +390,65 @@ class QuizzMachine:
 
     ###############################################################
 
-    def generate_quizzes(self, nb, model_for_generation, reverse_cleanup=False):
+    def generate_quizzes(self, nb, model_for_generation):
         c_quizzes = torch.empty(
             nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
         )
 
-        ar_mask_prompt = torch.zeros(c_quizzes.size(), device=self.device)
-        ar_mask_prompt[:, : ar_mask_prompt.size(1) // 2 + 1] = 1
-        ar_mask_solve = 1 - ar_mask_prompt
-        seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device)
+        ar_mask_first = torch.zeros(c_quizzes.size(), device=self.device)
+        ar_mask_first[:, : ar_mask_first.size(1) // 2 + 1] = 1
+        ar_mask_second = 1 - ar_mask_first
+        ar_mask_first[:, 0] = 0
+        ar_mask_second[:, 0] = 0
 
-        if reverse_cleanup:
-            warnings.warn("very high temperature with reversed cleanup", RuntimeWarning)
-            temperature = 10.0
-        else:
-            temperature = 1.0
+        seq_logproba = torch.empty(ar_mask_first.size(0), device=self.device)
+
+        temperature = 10.0
 
-        # warnings.warn("noise injection", RuntimeWarning)
-        # noise_std = torch.rand(1).item()
-        # self.logger(f"{noise_std=}")
+        # First, we generate the answer at high temperature
 
-        # mygpt.set_noise_injection(model_for_generation, noise_std)
+        c_quizzes[:, 0] = self.token_backward
 
         masked_inplace_autoregression(
             model=model_for_generation,
             batch_size=self.batch_size,
             input=c_quizzes,
-            ar_mask=ar_mask_prompt,
+            ar_mask=ar_mask_first,
             seq_logproba=seq_logproba,
             temperature=temperature,
             deterministic_synthesis=False,
             device=self.device,
         )
 
-        # mygpt.set_noise_injection(model_for_generation, 0.0)
-
         ave_seq_logproba = seq_logproba.mean()
 
+        # Then, we generate the prompt deterministically
+
         masked_inplace_autoregression(
             model=model_for_generation,
             batch_size=self.batch_size,
             input=c_quizzes,
-            ar_mask=ar_mask_solve,
+            ar_mask=ar_mask_second,
             seq_logproba=seq_logproba,
-            temperature=temperature,
+            temperature=1.0,
             deterministic_synthesis=True,
             device=self.device,
         )
 
-        if reverse_cleanup:
-            c_quizzes = self.reverse_time(c_quizzes)
-            masked_inplace_autoregression(
-                model=model_for_generation,
-                batch_size=self.batch_size,
-                input=c_quizzes,
-                ar_mask=ar_mask_solve,
-                seq_logproba=seq_logproba,
-                temperature=temperature,
-                deterministic_synthesis=True,
-                device=self.device,
-            )
+        # Then we reverse the quizz in time and re-generate the
+        # response, now deterministically
 
-            c_quizzes = self.reverse_time(c_quizzes)
-            masked_inplace_autoregression(
-                model=model_for_generation,
-                batch_size=self.batch_size,
-                input=c_quizzes,
-                ar_mask=ar_mask_solve,
-                seq_logproba=seq_logproba,
-                temperature=temperature,
-                deterministic_synthesis=True,
-                device=self.device,
-            )
+        c_quizzes = self.reverse_time(c_quizzes)
+
+        masked_inplace_autoregression(
+            model=model_for_generation,
+            batch_size=self.batch_size,
+            input=c_quizzes,
+            ar_mask=ar_mask_second,
+            seq_logproba=seq_logproba,
+            temperature=temperature,
+            deterministic_synthesis=True,
+            device=self.device,
+        )
 
         return c_quizzes, seq_logproba.mean()
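
For reference, the new masking scheme can be reproduced in isolation. The following is a minimal standalone sketch (the batch size, sequence length and device are made-up values, and plain tensors stand in for the class attributes used above); it shows how ar_mask_first and ar_mask_second split each sequence into the part generated first at high temperature and the part generated afterwards, with position 0 excluded from both masks since it holds the direction token (set to self.token_backward in the method).

import torch

# Toy values; the real code uses nb and self.train_w_quizzes.size(1).
nb, seq_len = 4, 11
device = "cpu"

c_quizzes = torch.empty(nb, seq_len, device=device, dtype=torch.int64)

# The first half (plus the middle position) is generated in the first
# autoregressive pass, the rest in the second pass; position 0 is left
# out of both masks because it carries the direction token.
ar_mask_first = torch.zeros(c_quizzes.size(), device=device)
ar_mask_first[:, : ar_mask_first.size(1) // 2 + 1] = 1
ar_mask_second = 1 - ar_mask_first
ar_mask_first[:, 0] = 0
ar_mask_second[:, 0] = 0

print(ar_mask_first[0])   # tensor([0., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0.])
print(ar_mask_second[0])  # tensor([0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1.])

The three masked_inplace_autoregression() calls in the diff then apply these masks in turn: ar_mask_first at temperature 10.0, ar_mask_second deterministically, and ar_mask_second again after reverse_time() to regenerate the response of the flipped quizz.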