From 60bf08d4197f2dd3a58bd900401c11d47225b0df Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 1 Jul 2024 12:06:31 +0300 Subject: [PATCH] Update. --- main.py | 3 ++- quizz_machine.py | 7 +++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index fd8ab41..67c57c0 100755 --- a/main.py +++ b/main.py @@ -437,7 +437,8 @@ def create_c_quizzes( for n in range(nb_correct.max() + 1): recorded[n].append(new_c_quizzes[nb_correct == n].clone()) - nv = [recorded[n][-1].size(0) for n in recorded.keys()] + nv = F.one_hot(nb_correct, num_classes=len(models) + 1).sum(0) + nv = " ".join([str(x.item()) for x in nv]) log_string(f"keep c_quizzes kept {nv} total {nb_validated()} / {nb_to_create}") diff --git a/quizz_machine.py b/quizz_machine.py index 806dde7..6f7492d 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -386,8 +386,11 @@ class QuizzMachine: ar_mask_solve = 1 - ar_mask_prompt seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device) - warnings.warn("very high temperature with reversed cleanup", RuntimeWarning) - temperature = 10 + if reverse_cleanup: + warnings.warn("very high temperature with reversed cleanup", RuntimeWarning) + temperature = 10.0 + else: + temperature = 1.0 # warnings.warn("noise injection", RuntimeWarning) # noise_std = torch.rand(1).item() -- 2.39.5