Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 2 Jul 2024 09:50:37 +0000 (12:50 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 2 Jul 2024 09:50:37 +0000 (12:50 +0300)
quizz_machine.py

index 5f19998..1a20563 100755 (executable)
@@ -395,10 +395,15 @@ class QuizzMachine:
             nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
         )
 
-        ar_mask_prompt = torch.zeros(c_quizzes.size(), device=self.device)
-        ar_mask_prompt[:, : ar_mask_prompt.size(1) // 2 + 1] = 1
-        ar_mask_solve = 1 - ar_mask_prompt
-        seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device)
+        c_quizzes[:, 0] = self.token_forward
+
+        ar_mask_first = torch.zeros(c_quizzes.size(), device=self.device)
+        ar_mask_first[:, : ar_mask_first.size(1) // 2 + 1] = 1
+        ar_mask_second = 1 - ar_mask_first
+        ar_mask_first[:, 0] = 0
+        ar_mask_second[:, 0] = 0
+
+        seq_logproba = torch.empty(ar_mask_first.size(0), device=self.device)
 
         if reverse_cleanup:
             warnings.warn("very high temperature with reversed cleanup", RuntimeWarning)
@@ -416,7 +421,7 @@ class QuizzMachine:
             model=model_for_generation,
             batch_size=self.batch_size,
             input=c_quizzes,
-            ar_mask=ar_mask_prompt,
+            ar_mask=ar_mask_first,
             seq_logproba=seq_logproba,
             temperature=temperature,
             deterministic_synthesis=False,
@@ -431,7 +436,7 @@ class QuizzMachine:
             model=model_for_generation,
             batch_size=self.batch_size,
             input=c_quizzes,
-            ar_mask=ar_mask_solve,
+            ar_mask=ar_mask_second,
             seq_logproba=seq_logproba,
             temperature=temperature,
             deterministic_synthesis=True,
@@ -440,11 +445,12 @@ class QuizzMachine:
 
         if reverse_cleanup:
             c_quizzes = self.reverse_time(c_quizzes)
+
             masked_inplace_autoregression(
                 model=model_for_generation,
                 batch_size=self.batch_size,
                 input=c_quizzes,
-                ar_mask=ar_mask_solve,
+                ar_mask=ar_mask_second,
                 seq_logproba=seq_logproba,
                 temperature=temperature,
                 deterministic_synthesis=True,
@@ -452,11 +458,12 @@ class QuizzMachine:
             )
 
             c_quizzes = self.reverse_time(c_quizzes)
+
             masked_inplace_autoregression(
                 model=model_for_generation,
                 batch_size=self.batch_size,
                 input=c_quizzes,
-                ar_mask=ar_mask_solve,
+                ar_mask=ar_mask_second,
                 seq_logproba=seq_logproba,
                 temperature=temperature,
                 deterministic_synthesis=True,