X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=quizz_machine.py;h=6f7492de907f7257888a46a93dad6f37bbdda964;hb=60bf08d4197f2dd3a58bd900401c11d47225b0df;hp=806dde718dbf36fec9eda46cd6bea3da25ff2e86;hpb=51540cefc448684d5086297d23e9a1805da4d405;p=culture.git diff --git a/quizz_machine.py b/quizz_machine.py index 806dde7..6f7492d 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -386,8 +386,11 @@ class QuizzMachine: ar_mask_solve = 1 - ar_mask_prompt seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device) - warnings.warn("very high temperature with reversed cleanup", RuntimeWarning) - temperature = 10 + if reverse_cleanup: + warnings.warn("very high temperature with reversed cleanup", RuntimeWarning) + temperature = 10.0 + else: + temperature = 1.0 # warnings.warn("noise injection", RuntimeWarning) # noise_std = torch.rand(1).item()