X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=quizz_machine.py;h=6f7492de907f7257888a46a93dad6f37bbdda964;hb=6b4e192557e03528ffd10364123de454aa9c9f08;hp=0dfbfcc96fb592b8faa3995baa90ff96e7a7f5a4;hpb=5e40cc7b666b8ce3532c7e98810192d7fb6c3a4f;p=culture.git

diff --git a/quizz_machine.py b/quizz_machine.py
index 0dfbfcc..6f7492d 100755
--- a/quizz_machine.py
+++ b/quizz_machine.py
@@ -386,8 +386,13 @@ class QuizzMachine:
         ar_mask_solve = 1 - ar_mask_prompt
         seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device)
 
+        if reverse_cleanup:
+            warnings.warn("very high temperature with reversed cleanup", RuntimeWarning)
+            temperature = 10.0
+        else:
+            temperature = 1.0
+
         # warnings.warn("noise injection", RuntimeWarning)
-        temperature = 1
         # noise_std = torch.rand(1).item()
         # self.logger(f"{noise_std=}")
 
@@ -469,16 +474,18 @@ class QuizzMachine:
         models,
         mode,
         min_ave_seq_logproba,
+        reverse_cleanup,
         n_epoch,
         result_dir,
     ):
         model_for_generation = Gang(models, nb_models_for_generation, mode)
         models_for_validation = models
         return self.create_c_quizzes(
-            nb,
-            model_for_generation,
-            models_for_validation,
-            min_ave_seq_logproba,
-            n_epoch,
-            result_dir,
+            nb=nb,
+            model_for_generation=model_for_generation,
+            models_for_validation=models_for_validation,
+            min_ave_seq_logproba=min_ave_seq_logproba,
+            reverse_cleanup=reverse_cleanup,
+            n_epoch=n_epoch,
+            result_dir=result_dir,
         )
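
For reference, the first hunk selects the sampling temperature from the new reverse_cleanup argument: when the flag is set, a RuntimeWarning is emitted and the temperature is raised to 10.0, otherwise it stays at 1.0. The sketch below reproduces only that branch as a standalone function so it can be exercised outside the class; the pick_temperature name is hypothetical and not part of quizz_machine.py.

import warnings

def pick_temperature(reverse_cleanup: bool) -> float:
    # Hypothetical helper mirroring the branch added in the first hunk:
    # reversed cleanup samples at a very high temperature and warns about it.
    if reverse_cleanup:
        warnings.warn("very high temperature with reversed cleanup", RuntimeWarning)
        return 10.0
    return 1.0

print(pick_temperature(False))  # 1.0
print(pick_temperature(True))   # emits the RuntimeWarning, then prints 10.0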