X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=quizz_machine.py;h=6f7492de907f7257888a46a93dad6f37bbdda964;hb=60bf08d4197f2dd3a58bd900401c11d47225b0df;hp=eae256ba8fe0ee54b8dfcd1888df17f99526febf;hpb=66d210bd5e04ae58f9e1495df77f1f975ee99c56;p=culture.git diff --git a/quizz_machine.py b/quizz_machine.py index eae256b..6f7492d 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -386,8 +386,11 @@ class QuizzMachine: ar_mask_solve = 1 - ar_mask_prompt seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device) - warnings.warn("very high temperature with reversed cleanup", RuntimeWarning) - temperature = 10 + if reverse_cleanup: + warnings.warn("very high temperature with reversed cleanup", RuntimeWarning) + temperature = 10.0 + else: + temperature = 1.0 # warnings.warn("noise injection", RuntimeWarning) # noise_std = torch.rand(1).item() @@ -471,16 +474,18 @@ class QuizzMachine: models, mode, min_ave_seq_logproba, + reverse_cleanup, n_epoch, result_dir, ): model_for_generation = Gang(models, nb_models_for_generation, mode) models_for_validation = models return self.create_c_quizzes( - nb, - model_for_generation, - models_for_validation, - min_ave_seq_logproba, - n_epoch, - result_dir, + nb=nb, + model_for_generation=model_for_generation, + models_for_validation=models_for_validation, + min_ave_seq_logproba=min_ave_seq_logproba, + reverse_cleanup=reverse_cleanup, + n_epoch=n_epoch, + result_dir=result_dir, )