sum_logits += c_quizzes.size(0) * ave_seq_logproba
sum_nb_c_quizzes += c_quizzes.size(0)
- nb_correct = quizz_machine.compute_correctness(c_quizzes, models)
+ nb_correct = quizz_machine.compute_correctness(
+ c_quizzes, models, both_direction=True
+ )
if args.dirty_debug:
nb_correct = torch.randint(
return torch.cat([c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1)
- def compute_correctness(self, c_quizzes, models_for_validation):
+ def compute_correctness(
+ self, c_quizzes, models_for_validation, both_direction=True
+ ):
reversed_c_quizzes = self.reverse_time(c_quizzes)
ar_mask = self.make_ar_mask(c_quizzes)
correct = (c_quizzes == result).long().min(dim=-1).values
- reversed_result = reversed_c_quizzes.clone()
+ if both_direction:
+ reversed_result = reversed_c_quizzes.clone()
- masked_inplace_autoregression(
- model=model,
- batch_size=self.batch_size,
- input=reversed_result,
- ar_mask=ar_mask,
- seq_logproba=seq_logproba,
- temperature=1.0,
- deterministic_synthesis=True,
- # progress_bar_desc="solving reversed c_quizzes",
- device=self.device,
- )
+ masked_inplace_autoregression(
+ model=model,
+ batch_size=self.batch_size,
+ input=reversed_result,
+ ar_mask=ar_mask,
+ seq_logproba=seq_logproba,
+ temperature=1.0,
+ deterministic_synthesis=True,
+ # progress_bar_desc="solving reversed c_quizzes",
+ device=self.device,
+ )
- reversed_correct = (
- (reversed_c_quizzes == reversed_result).long().min(dim=-1).values
- )
+ reversed_correct = (
+ (reversed_c_quizzes == reversed_result).long().min(dim=-1).values
+ )
+
+ correct *= reversed_correct
+
+ # endif
- nb_correct += correct * reversed_correct
+ nb_correct += correct
return nb_correct