X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=quizz_machine.py;h=6f7492de907f7257888a46a93dad6f37bbdda964;hb=60bf08d4197f2dd3a58bd900401c11d47225b0df;hp=7b0b877bdc0b04b0eec304850630461878493b96;hpb=5c5668b0e52e2ae579d49ba8a44fafe2339ad8c0;p=culture.git diff --git a/quizz_machine.py b/quizz_machine.py index 7b0b877..6f7492d 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -386,8 +386,13 @@ class QuizzMachine: ar_mask_solve = 1 - ar_mask_prompt seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device) + if reverse_cleanup: + warnings.warn("very high temperature with reversed cleanup", RuntimeWarning) + temperature = 10.0 + else: + temperature = 1.0 + # warnings.warn("noise injection", RuntimeWarning) - temperature = 1 # noise_std = torch.rand(1).item() # self.logger(f"{noise_std=}") @@ -445,11 +450,15 @@ class QuizzMachine: model_for_generation, models_for_validation, min_ave_seq_logproba, + reverse_cleanup, n_epoch, result_dir, ): c_quizzes, ave_seq_logproba = self.generate_quizzes( - nb, model_for_generation, min_ave_seq_logproba + nb, + model_for_generation=model_for_generation, + min_ave_seq_logproba=min_ave_seq_logproba, + reverse_cleanup=reverse_cleanup, ) nb_correct = self.comput_correctness(c_quizzes, models_for_validation) @@ -465,16 +474,18 @@ class QuizzMachine: models, mode, min_ave_seq_logproba, + reverse_cleanup, n_epoch, result_dir, ): model_for_generation = Gang(models, nb_models_for_generation, mode) models_for_validation = models return self.create_c_quizzes( - nb, - model_for_generation, - models_for_validation, - min_ave_seq_logproba, - n_epoch, - result_dir, + nb=nb, + model_for_generation=model_for_generation, + models_for_validation=models_for_validation, + min_ave_seq_logproba=min_ave_seq_logproba, + reverse_cleanup=reverse_cleanup, + n_epoch=n_epoch, + result_dir=result_dir, )