parser.add_argument("--generation_temperature", type=float, default=2)
-parser.add_argument("--c_quiz_validation_mode", type=str, default="proba")
+parser.add_argument("--c_quiz_validation_mode", type=str, default="predict")
parser.add_argument("--dirty_debug", action="store_true", default=False)
seq_logproba = torch.zeros(nb, device=self.device)
- # First, we generate the answer at high temperature
-
- c_quizzes[:, 0] = self.token_backward
- c_quizzes[:, 1 + self.answer_len] = self.token_backward
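+ # Mark both segments as forward: the prompt segment comes first, then the answer segment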
+ c_quizzes[:, 0] = self.token_forward
+ c_quizzes[:, 1 + self.prompt_len] = self.token_forward
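+ # Generate the prompt (the first segment) at temperature 1.0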
masked_inplace_autoregression(
model=model_for_generation,
batch_size=self.batch_size,
input=c_quizzes,
ar_mask=self.make_ar_mask(c_quizzes, first=True),
seq_logproba=seq_logproba,
- temperature=temperature,
- deterministic_synthesis=False,
- device=self.device,
- )
-
- # Then, we generate the prompt at low temperature
-
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes,
- ar_mask=self.make_ar_mask(c_quizzes),
- seq_logproba=seq_logproba,
temperature=1.0,
deterministic_synthesis=False,
device=self.device,
)
- # Then we return the quizz, and re-generate the response, now
- # at low temperature
-
- c_quizzes = self.reverse_time(c_quizzes)
-
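+ # Then generate the answer (the second segment)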
masked_inplace_autoregression(
model=model_for_generation,
batch_size=self.batch_size,