X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=quizz_machine.py;h=6cad6a1c74a5b3fc04d04ce253c65f94286cf06f;hb=db8c21397d370ae16fd6078858c649e2ab14fe4e;hp=18d0e0b49fd6dc48c81ba42db634f6b6d6043875;hpb=3dca75c7144421022e45cea9288cd87957ff5867;p=culture.git diff --git a/quizz_machine.py b/quizz_machine.py index 18d0e0b..6cad6a1 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -383,9 +383,9 @@ class QuizzMachine: ar_mask_solve = 1 - ar_mask_prompt seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device) - # bracketing of the temperature to get the target logproba + # bracketing of the temperature to get the target logproba if + # min_ave_seq_logproba is not None - warnings.warn("high temperature!", RuntimeWarning) temperature = 2 d_temperature = 1 / 3