Update: simplify c_quizz generation (backward sampling at high temperature, then unconditional time reversal and deterministic re-generation).
author     François Fleuret <francois@fleuret.org>
           Tue, 2 Jul 2024 10:07:46 +0000 (13:07 +0300)
committer  François Fleuret <francois@fleuret.org>
           Tue, 2 Jul 2024 10:07:46 +0000 (13:07 +0300)
diff --git a/quizz_machine.py b/quizz_machine.py

index 1a20563..5807b66 100755
@@ -395,8 +395,6 @@ class QuizzMachine:
             nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
         )
 
-        c_quizzes[:, 0] = self.token_forward
-
         ar_mask_first = torch.zeros(c_quizzes.size(), device=self.device)
         ar_mask_first[:, : ar_mask_first.size(1) // 2 + 1] = 1
         ar_mask_second = 1 - ar_mask_first
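
The masks kept by this hunk split every quizz into a first part (the first half of the sequence plus the direction token) and the complementary second part. The following stand-alone sketch shows that split with an assumed batch size and quizz length, and without the device argument; it is for illustration only, not code from the repository.

    import torch

    nb, length = 4, 10                    # assumed batch size and quizz length
    c_quizzes = torch.zeros(nb, length, dtype=torch.int64)

    # First mask: the first half of the sequence plus one token.
    ar_mask_first = torch.zeros(c_quizzes.size())
    ar_mask_first[:, : ar_mask_first.size(1) // 2 + 1] = 1
    # Second mask: the complement.
    ar_mask_second = 1 - ar_mask_first

    print(ar_mask_first[0])   # tensor([1., 1., 1., 1., 1., 1., 0., 0., 0., 0.])
    print(ar_mask_second[0])  # tensor([0., 0., 0., 0., 0., 0., 1., 1., 1., 1.])
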
@@ -406,16 +404,13 @@ class QuizzMachine:
         seq_logproba = torch.empty(ar_mask_first.size(0), device=self.device)
 
         if reverse_cleanup:
-            warnings.warn("very high temperature with reversed cleanup", RuntimeWarning)
             temperature = 10.0
         else:
             temperature = 1.0
 
-        # warnings.warn("noise injection", RuntimeWarning)
-        # noise_std = torch.rand(1).item()
-        # self.logger(f"{noise_std=}")
+        # First, we generate the answer at high temperature
 
-        # mygpt.set_noise_injection(model_for_generation, noise_std)
+        c_quizzes[:, 0] = self.token_backward
 
         masked_inplace_autoregression(
             model=model_for_generation,
@@ -428,10 +423,10 @@ class QuizzMachine:
             device=self.device,
         )
 
-        # mygpt.set_noise_injection(model_for_generation, 0.0)
-
         ave_seq_logproba = seq_logproba.mean()
 
+        # Then, we generate the prompt deterministically
+
         masked_inplace_autoregression(
             model=model_for_generation,
             batch_size=self.batch_size,
@@ -443,31 +438,20 @@ class QuizzMachine:
             device=self.device,
         )
 
-        if reverse_cleanup:
-            c_quizzes = self.reverse_time(c_quizzes)
+        # Then we reverse the quizz in time and re-generate the response,
+        # now deterministically
 
-            masked_inplace_autoregression(
-                model=model_for_generation,
-                batch_size=self.batch_size,
-                input=c_quizzes,
-                ar_mask=ar_mask_second,
-                seq_logproba=seq_logproba,
-                temperature=temperature,
-                deterministic_synthesis=True,
-                device=self.device,
-            )
-
-            c_quizzes = self.reverse_time(c_quizzes)
+        c_quizzes = self.reverse_time(c_quizzes)
 
-            masked_inplace_autoregression(
-                model=model_for_generation,
-                batch_size=self.batch_size,
-                input=c_quizzes,
-                ar_mask=ar_mask_second,
-                seq_logproba=seq_logproba,
-                temperature=temperature,
-                deterministic_synthesis=True,
-                device=self.device,
-            )
+        masked_inplace_autoregression(
+            model=model_for_generation,
+            batch_size=self.batch_size,
+            input=c_quizzes,
+            ar_mask=ar_mask_second,
+            seq_logproba=seq_logproba,
+            temperature=temperature,
+            deterministic_synthesis=True,
+            device=self.device,
+        )
 
         return c_quizzes, seq_logproba.mean()
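
After this patch the generation path is a fixed three-pass procedure: mark the quizz as backward and sample the answer half at high temperature, complete the prompt half deterministically, then reverse the quizz in time and re-generate the answer deterministically. The sketch below is a hedged, self-contained illustration of that control flow only: toy_model, masked_autoregression and reverse_time are simplified stand-ins for the generation model, for masked_inplace_autoregression and for QuizzMachine.reverse_time, and the token ids, sizes, half-swapping reversal and temperatures (10.0 then 1.0, i.e. the reverse_cleanup branch) are assumptions made to keep the example runnable.

    import torch

    nb, length, vocab = 4, 10, 16         # assumed sizes, for illustration
    token_backward = 1                    # assumed id of the backward marker

    def toy_model(x):
        # Stand-in model: random next-token logits at every position.
        return torch.randn(x.size(0), x.size(1), vocab)

    def masked_autoregression(model, x, ar_mask, temperature, deterministic):
        # Simplified stand-in for masked_inplace_autoregression: fill the
        # positions where ar_mask == 1, in place, from left to right.
        for t in range(x.size(1)):
            if ar_mask[0, t] == 0:
                continue
            logits = model(x)[:, t] / temperature
            if deterministic:
                x[:, t] = logits.argmax(dim=1)
            else:
                x[:, t] = torch.multinomial(logits.softmax(dim=1), 1).squeeze(1)

    def reverse_time(x):
        # Stand-in for QuizzMachine.reverse_time: swap the two halves so the
        # generated answer becomes the prompt of the reversed quizz.
        h = x.size(1) // 2
        return torch.cat([x[:, h:], x[:, :h]], dim=1)

    c_quizzes = torch.zeros(nb, length, dtype=torch.int64)
    ar_mask_first = torch.zeros(nb, length)
    ar_mask_first[:, : length // 2 + 1] = 1
    ar_mask_second = 1 - ar_mask_first

    # 1. Mark the quizz as backward and sample the answer at high temperature.
    c_quizzes[:, 0] = token_backward
    masked_autoregression(toy_model, c_quizzes, ar_mask_first, 10.0, False)

    # 2. Generate the prompt deterministically.
    masked_autoregression(toy_model, c_quizzes, ar_mask_second, 1.0, True)

    # 3. Reverse the quizz in time and re-generate the answer deterministically.
    c_quizzes = reverse_time(c_quizzes)
    masked_autoregression(toy_model, c_quizzes, ar_mask_second, 1.0, True)

    print(c_quizzes)

Compared with the pre-patch version, the conditional reverse_cleanup block that reversed the quizz, re-generated, reversed back and re-generated once more is replaced by a single unconditional reversal followed by one deterministic re-generation of the response.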