Update.
author: François Fleuret <francois@fleuret.org>
Mon, 15 Jul 2024 21:12:41 +0000 (23:12 +0200)
committer: François Fleuret <francois@fleuret.org>
Mon, 15 Jul 2024 21:12:41 +0000 (23:12 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index 9d36aba..5c58beb 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -90,7 +90,7 @@ parser.add_argument("--proba_not_understands", type=float, default=0.5)
 
 parser.add_argument("--generation_temperature", type=float, default=2)
 
-parser.add_argument("--c_quiz_validation_mode", type=str, default="proba")
+parser.add_argument("--c_quiz_validation_mode", type=str, default="predict")
 
 parser.add_argument("--dirty_debug", action="store_true", default=False)
 
index f66258a..0f834dc 100755 (executable)
@@ -537,10 +537,8 @@ class QuizMachine:
 
         seq_logproba = torch.zeros(nb, device=self.device)
 
-        # First, we generate the answer at high temperature
-
-        c_quizzes[:, 0] = self.token_backward
-        c_quizzes[:, 1 + self.answer_len] = self.token_backward
+        c_quizzes[:, 0] = self.token_forward
+        c_quizzes[:, 1 + self.prompt_len] = self.token_forward
 
         masked_inplace_autoregression(
             model=model_for_generation,
@@ -548,29 +546,11 @@ class QuizMachine:
             input=c_quizzes,
             ar_mask=self.make_ar_mask(c_quizzes, first=True),
             seq_logproba=seq_logproba,
-            temperature=temperature,
-            deterministic_synthesis=False,
-            device=self.device,
-        )
-
-        # Then, we generate the prompt at low temperature
-
-        masked_inplace_autoregression(
-            model=model_for_generation,
-            batch_size=self.batch_size,
-            input=c_quizzes,
-            ar_mask=self.make_ar_mask(c_quizzes),
-            seq_logproba=seq_logproba,
             temperature=1.0,
             deterministic_synthesis=False,
             device=self.device,
         )
 
-        # Then we return the quizz, and re-generate the response, now
-        # at low temperature
-
-        c_quizzes = self.reverse_time(c_quizzes)
-
         masked_inplace_autoregression(
             model=model_for_generation,
             batch_size=self.batch_size,