Update.
[culture.git] / quizz_machine.py
index 49e7835..c5870d0 100755 (executable)
@@ -378,7 +378,9 @@ class QuizzMachine:
             nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
         )
 
-        ar_mask = torch.full(c_quizzes.size(), 1, device=self.device)
+        ar_mask_prompt = torch.zeros(c_quizzes.size(), device=self.device)
+        ar_mask_prompt[:, : ar_mask_prompt.size(1) // 2 + 1] = 1
+        ar_mask_solve = 1 - ar_mask_prompt
-        seq_logproba = torch.empty(ar_mask.size(0), device=self.device)
+        seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device)
 
         # bracketing of the temperature to get the target logproba
@@ -393,7 +395,7 @@ class QuizzMachine:
                 model=model_for_generation,
                 batch_size=self.batch_size,
                 input=c_quizzes,
-                ar_mask=ar_mask,
+                ar_mask=ar_mask_prompt,
                 seq_logproba=seq_logproba,
                 temperature=temperature,
                 deterministic_synthesis=False,
@@ -403,6 +405,18 @@ class QuizzMachine:
 
             ave_seq_logproba = seq_logproba.mean()
 
+            masked_inplace_autoregression(
+                model=model_for_generation,
+                batch_size=self.batch_size,
+                input=c_quizzes,
+                ar_mask=ar_mask_solve,
+                seq_logproba=seq_logproba,
+                temperature=temperature,
+                deterministic_synthesis=True,
+                # progress_bar_desc="sampling c_quizzes",
+                device=self.device,
+            )
+
             # If we do not have target logprobs, get out now
             if min_ave_seq_logproba is None:
                 break