)
def compute_correctness(
- self, c_quizzes, models_for_validation, both_directions=True
+ self, c_quizzes, models_for_validation, both_directions=False
):
reversed_c_quizzes = self.reverse_time(c_quizzes)
###############################################################
- def generate_quizzes(self, nb, model_for_generation, reverse_cleanup=False):
+ def generate_quizzes(self, nb, model_for_generation):
c_quizzes = torch.empty(
nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
)
seq_logproba = torch.empty(ar_mask_first.size(0), device=self.device)
- if reverse_cleanup:
- temperature = 10.0
- else:
- temperature = 1.0
+ temperature = 10.0
# First, we generate the answer at high temperature
input=c_quizzes,
ar_mask=ar_mask_second,
seq_logproba=seq_logproba,
- temperature=temperature,
+ temperature=1.0,
deterministic_synthesis=True,
device=self.device,
)