+ reversed_correct = (
+ (reversed_c_quizzes == reversed_result).long().min(dim=-1).values
+ )
+
+ nb_correct += correct * reversed_correct
+
+ return nb_correct
+
+ ###############################################################
+
+ def generate_quizzes(
+ self, nb, model_for_generation, min_ave_seq_logproba, reverse_cleanup=False
+ ):
+ c_quizzes = torch.empty(
+ nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
+ )
+
+ ar_mask_prompt = torch.zeros(c_quizzes.size(), device=self.device)
+ ar_mask_prompt[:, : ar_mask_prompt.size(1) // 2 + 1] = 1
+ ar_mask_solve = 1 - ar_mask_prompt
+ seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device)
+
+ if reverse_cleanup:
+ warnings.warn("very high temperature with reversed cleanup", RuntimeWarning)
+ temperature = 10.0
+ else:
+ temperature = 1.0
+
+ # warnings.warn("noise injection", RuntimeWarning)
+ # noise_std = torch.rand(1).item()
+ # self.logger(f"{noise_std=}")
+
+ # mygpt.set_noise_injection(model_for_generation, noise_std)
+
+ masked_inplace_autoregression(
+ model=model_for_generation,
+ batch_size=self.batch_size,
+ input=c_quizzes,
+ ar_mask=ar_mask_prompt,
+ seq_logproba=seq_logproba,
+ temperature=temperature,
+ deterministic_synthesis=False,
+ # progress_bar_desc="sampling c_quizzes",
+ device=self.device,
+ )
+
+ # mygpt.set_noise_injection(model_for_generation, 0.0)
+
+ ave_seq_logproba = seq_logproba.mean()
+
+ masked_inplace_autoregression(
+ model=model_for_generation,
+ batch_size=self.batch_size,
+ input=c_quizzes,
+ ar_mask=ar_mask_solve,
+ seq_logproba=seq_logproba,
+ temperature=temperature,
+ deterministic_synthesis=True,
+ # progress_bar_desc="sampling c_quizzes",
+ device=self.device,
+ )
+
+ if reverse_cleanup:
+ c_quizzes = self.reverse_time(c_quizzes)
+ masked_inplace_autoregression(
+ model=model_for_generation,
+ batch_size=self.batch_size,
+ input=c_quizzes,
+ ar_mask=ar_mask_solve,
+ seq_logproba=seq_logproba,
+ temperature=temperature,
+ deterministic_synthesis=True,
+ # progress_bar_desc="sampling c_quizzes",
+ device=self.device,