nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
)
- c_quizzes[:, 0] = self.token_forward
-
ar_mask_first = torch.zeros(c_quizzes.size(), device=self.device)
ar_mask_first[:, : ar_mask_first.size(1) // 2 + 1] = 1
ar_mask_second = 1 - ar_mask_first
seq_logproba = torch.empty(ar_mask_first.size(0), device=self.device)
if reverse_cleanup:
- warnings.warn("very high temperature with reversed cleanup", RuntimeWarning)
temperature = 10.0
else:
temperature = 1.0
- # warnings.warn("noise injection", RuntimeWarning)
- # noise_std = torch.rand(1).item()
- # self.logger(f"{noise_std=}")
+ # First, we generate the answer at high temperature
- # mygpt.set_noise_injection(model_for_generation, noise_std)
+ c_quizzes[:, 0] = self.token_backward
masked_inplace_autoregression(
model=model_for_generation,
device=self.device,
)
- # mygpt.set_noise_injection(model_for_generation, 0.0)
-
ave_seq_logproba = seq_logproba.mean()
+ # Then, we generate the prompt deterministically
+
masked_inplace_autoregression(
model=model_for_generation,
batch_size=self.batch_size,
device=self.device,
)
- if reverse_cleanup:
- c_quizzes = self.reverse_time(c_quizzes)
+ # Then we reverse the quiz, and re-generate the response, now
+ # deterministically
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes,
- ar_mask=ar_mask_second,
- seq_logproba=seq_logproba,
- temperature=temperature,
- deterministic_synthesis=True,
- device=self.device,
- )
-
- c_quizzes = self.reverse_time(c_quizzes)
+ c_quizzes = self.reverse_time(c_quizzes)
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes,
- ar_mask=ar_mask_second,
- seq_logproba=seq_logproba,
- temperature=temperature,
- deterministic_synthesis=True,
- device=self.device,
- )
+ masked_inplace_autoregression(
+ model=model_for_generation,
+ batch_size=self.batch_size,
+ input=c_quizzes,
+ ar_mask=ar_mask_second,
+ seq_logproba=seq_logproba,
+ temperature=temperature,
+ deterministic_synthesis=True,
+ device=self.device,
+ )
return c_quizzes, seq_logproba.mean()