From fcb71a73da3a27f81383e3000b9ad1ee8da45926 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 1 Jul 2024 14:09:03 +0300 Subject: [PATCH 1/1] Update. --- main.py | 5 ++++- quizz_machine.py | 12 ++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index 10c7b49..eb0ef27 100755 --- a/main.py +++ b/main.py @@ -435,13 +435,16 @@ def create_c_quizzes( f"keep c_quizzes kept {nv} nb_accumulated {nb_validated} / {nb_to_create}" ) - # ------------------------------------------------------------ + # store the new c_quizzes which have been validated new_c_quizzes = valid_c_quizzes(recorded, standard_validity) quizz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True) quizz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False) + # save a bunch of images to investigate what quizzes with a + # certain nb of correct predictions look like + for n in range(len(models) + 1): s = ( "_validated" diff --git a/quizz_machine.py b/quizz_machine.py index 88f2c1c..d591d79 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -398,4 +398,16 @@ class QuizzMachine: device=self.device, ) + c_quizzes = self.reverse_time(c_quizzes) + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes, + ar_mask=ar_mask_solve, + seq_logproba=seq_logproba, + temperature=temperature, + deterministic_synthesis=True, + device=self.device, + ) + return c_quizzes, seq_logproba.mean() -- 2.20.1