+ masked_inplace_autoregression(
+ model=model_for_generation,
+ batch_size=self.batch_size,
+ input=c_quizzes,
+ ar_mask=ar_mask_prompt,
+ seq_logproba=seq_logproba,
+ temperature=temperature,
+ deterministic_synthesis=False,
+ # progress_bar_desc="sampling c_quizzes",
+ device=self.device,
+ )
+
+ # mygpt.set_noise_injection(model_for_generation, 0.0)
+
+ ave_seq_logproba = seq_logproba.mean()
+
+ masked_inplace_autoregression(
+ model=model_for_generation,
+ batch_size=self.batch_size,
+ input=c_quizzes,
+ ar_mask=ar_mask_solve,
+ seq_logproba=seq_logproba,
+ temperature=temperature,
+ deterministic_synthesis=True,
+ # progress_bar_desc="sampling c_quizzes",
+ device=self.device,
+ )
+
+ if reverse_cleanup:
+ c_quizzes = self.reverse_time(c_quizzes)
+ masked_inplace_autoregression(
+ model=model_for_generation,
+ batch_size=self.batch_size,
+ input=c_quizzes,
+ ar_mask=ar_mask_solve,
+ seq_logproba=seq_logproba,
+ temperature=temperature,
+ deterministic_synthesis=True,
+ # progress_bar_desc="sampling c_quizzes",
+ device=self.device,
+ )
+
+ return c_quizzes, seq_logproba.mean()
+
+ ######################################################################
+
+ def create_c_quizzes(
+ self,
+ nb,
+ model_for_generation,
+ models_for_validation,
+ min_ave_seq_logproba,
+ reverse_cleanup,
+ n_epoch,
+ result_dir,
+ ):
+ c_quizzes, ave_seq_logproba = self.generate_quizzes(
+ nb,
+ model_for_generation=model_for_generation,
+ min_ave_seq_logproba=min_ave_seq_logproba,
+ reverse_cleanup=reverse_cleanup,
+ )
+
+ nb_correct = self.comput_correctness(c_quizzes, models_for_validation)
+
+ return c_quizzes, nb_correct, ave_seq_logproba
+
+ ######################################################################
+
+ def gang_create_c_quizzes(
+ self,
+ nb,
+ nb_models_for_generation,
+ models,
+ mode,
+ min_ave_seq_logproba,
+ reverse_cleanup,
+ n_epoch,
+ result_dir,
+ ):
+ model_for_generation = Gang(models, nb_models_for_generation, mode)
+ models_for_validation = models
+ return self.create_c_quizzes(
+ nb=nb,
+ model_for_generation=model_for_generation,
+ models_for_validation=models_for_validation,
+ min_ave_seq_logproba=min_ave_seq_logproba,
+ reverse_cleanup=reverse_cleanup,
+ n_epoch=n_epoch,
+ result_dir=result_dir,
+ )