ar_mask_solve = 1 - ar_mask_prompt
seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device)
+ if reverse_cleanup:
+ warnings.warn("very high temperature with reversed cleanup", RuntimeWarning)
+ temperature = 10.0
+ else:
+ temperature = 1.0
+
# warnings.warn("noise injection", RuntimeWarning)
- temperature = 1
# noise_std = torch.rand(1).item()
# self.logger(f"{noise_std=}")
models,
mode,
min_ave_seq_logproba,
+ reverse_cleanup,
n_epoch,
result_dir,
):
model_for_generation = Gang(models, nb_models_for_generation, mode)
models_for_validation = models
return self.create_c_quizzes(
- nb,
- model_for_generation,
- models_for_validation,
- min_ave_seq_logproba,
- n_epoch,
- result_dir,
+ nb=nb,
+ model_for_generation=model_for_generation,
+ models_for_validation=models_for_validation,
+ min_ave_seq_logproba=min_ave_seq_logproba,
+ reverse_cleanup=reverse_cleanup,
+ n_epoch=n_epoch,
+ result_dir=result_dir,
)