Update.
[culture.git] / quizz_machine.py
index 5f19998..5807b66 100755 (executable)
@@ -395,72 +395,63 @@ class QuizzMachine:
             nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
         )
 
-        ar_mask_prompt = torch.zeros(c_quizzes.size(), device=self.device)
-        ar_mask_prompt[:, : ar_mask_prompt.size(1) // 2 + 1] = 1
-        ar_mask_solve = 1 - ar_mask_prompt
-        seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device)
+        ar_mask_first = torch.zeros(c_quizzes.size(), device=self.device)
+        ar_mask_first[:, : ar_mask_first.size(1) // 2 + 1] = 1
+        ar_mask_second = 1 - ar_mask_first
+        ar_mask_first[:, 0] = 0
+        ar_mask_second[:, 0] = 0
+
+        seq_logproba = torch.empty(ar_mask_first.size(0), device=self.device)
 
         if reverse_cleanup:
-            warnings.warn("very high temperature with reversed cleanup", RuntimeWarning)
             temperature = 10.0
         else:
             temperature = 1.0
 
-        # warnings.warn("noise injection", RuntimeWarning)
-        # noise_std = torch.rand(1).item()
-        # self.logger(f"{noise_std=}")
+        # First, we generate the answer at high temperature
 
-        # mygpt.set_noise_injection(model_for_generation, noise_std)
+        c_quizzes[:, 0] = self.token_backward
 
         masked_inplace_autoregression(
             model=model_for_generation,
             batch_size=self.batch_size,
             input=c_quizzes,
-            ar_mask=ar_mask_prompt,
+            ar_mask=ar_mask_first,
             seq_logproba=seq_logproba,
             temperature=temperature,
             deterministic_synthesis=False,
             device=self.device,
         )
 
-        # mygpt.set_noise_injection(model_for_generation, 0.0)
-
         ave_seq_logproba = seq_logproba.mean()
 
+        # Then, we generate the prompt deterministically
+
         masked_inplace_autoregression(
             model=model_for_generation,
             batch_size=self.batch_size,
             input=c_quizzes,
-            ar_mask=ar_mask_solve,
+            ar_mask=ar_mask_second,
             seq_logproba=seq_logproba,
             temperature=temperature,
             deterministic_synthesis=True,
             device=self.device,
         )
 
-        if reverse_cleanup:
-            c_quizzes = self.reverse_time(c_quizzes)
-            masked_inplace_autoregression(
-                model=model_for_generation,
-                batch_size=self.batch_size,
-                input=c_quizzes,
-                ar_mask=ar_mask_solve,
-                seq_logproba=seq_logproba,
-                temperature=temperature,
-                deterministic_synthesis=True,
-                device=self.device,
-            )
+        # Then we reverse the quizz in time, and re-generate the response,
+        # now deterministically
 
-            c_quizzes = self.reverse_time(c_quizzes)
-            masked_inplace_autoregression(
-                model=model_for_generation,
-                batch_size=self.batch_size,
-                input=c_quizzes,
-                ar_mask=ar_mask_solve,
-                seq_logproba=seq_logproba,
-                temperature=temperature,
-                deterministic_synthesis=True,
-                device=self.device,
-            )
+        c_quizzes = self.reverse_time(c_quizzes)
+
+        masked_inplace_autoregression(
+            model=model_for_generation,
+            batch_size=self.batch_size,
+            input=c_quizzes,
+            ar_mask=ar_mask_second,
+            seq_logproba=seq_logproba,
+            temperature=temperature,
+            deterministic_synthesis=True,
+            device=self.device,
+        )
 
         return c_quizzes, seq_logproba.mean()