From: François Fleuret Date: Sun, 30 Jun 2024 10:10:56 +0000 (+0300) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=3dca75c7144421022e45cea9288cd87957ff5867;p=culture.git Update. --- diff --git a/quizz_machine.py b/quizz_machine.py index c5870d0..18d0e0b 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -379,13 +379,14 @@ class QuizzMachine: ) ar_mask_prompt = torch.zeros(c_quizzes.size(), device=self.device) - ar_mask_prompt[:, ar_mask_prompt.size(1) // 2 + 1] = 1 + ar_mask_prompt[:, : ar_mask_prompt.size(1) // 2 + 1] = 1 ar_mask_solve = 1 - ar_mask_prompt - seq_logproba = torch.empty(ar_mask.size(0), device=self.device) + seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device) # bracketing of the temperature to get the target logproba - temperature = 1 + warnings.warn("high temperature!", RuntimeWarning) + temperature = 2 d_temperature = 1 / 3 while True: