return torch.cat([c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1)
- def compute_correctness(self, c_quizzes, models_for_validation):
+ def compute_correctness(
+ self, c_quizzes, models_for_validation, both_direction=True
+ ):
reversed_c_quizzes = self.reverse_time(c_quizzes)
ar_mask = self.make_ar_mask(c_quizzes)
correct = (c_quizzes == result).long().min(dim=-1).values
- reversed_result = reversed_c_quizzes.clone()
+ if both_direction:
+ reversed_result = reversed_c_quizzes.clone()
- masked_inplace_autoregression(
- model=model,
- batch_size=self.batch_size,
- input=reversed_result,
- ar_mask=ar_mask,
- seq_logproba=seq_logproba,
- temperature=1.0,
- deterministic_synthesis=True,
- # progress_bar_desc="solving reversed c_quizzes",
- device=self.device,
- )
+ masked_inplace_autoregression(
+ model=model,
+ batch_size=self.batch_size,
+ input=reversed_result,
+ ar_mask=ar_mask,
+ seq_logproba=seq_logproba,
+ temperature=1.0,
+ deterministic_synthesis=True,
+ # progress_bar_desc="solving reversed c_quizzes",
+ device=self.device,
+ )
- reversed_correct = (
- (reversed_c_quizzes == reversed_result).long().min(dim=-1).values
- )
+ reversed_correct = (
+ (reversed_c_quizzes == reversed_result).long().min(dim=-1).values
+ )
+
+ correct *= reversed_correct
+
+ # endif
- nb_correct += correct * reversed_correct
+ nb_correct += correct
return nb_correct
device=self.device,
)
+ c_quizzes = self.reverse_time(c_quizzes)
+ masked_inplace_autoregression(
+ model=model_for_generation,
+ batch_size=self.batch_size,
+ input=c_quizzes,
+ ar_mask=ar_mask_solve,
+ seq_logproba=seq_logproba,
+ temperature=temperature,
+ deterministic_synthesis=True,
+ device=self.device,
+ )
+
return c_quizzes, seq_logproba.mean()