From: François Fleuret Date: Mon, 1 Jul 2024 11:09:03 +0000 (+0300) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=fcb71a73da3a27f81383e3000b9ad1ee8da45926;p=culture.git Update. --- diff --git a/main.py b/main.py index 10c7b49..eb0ef27 100755 --- a/main.py +++ b/main.py @@ -435,13 +435,16 @@ def create_c_quizzes( f"keep c_quizzes kept {nv} nb_accumulated {nb_validated} / {nb_to_create}" ) - # ------------------------------------------------------------ + # store the new c_quizzes which have been validated new_c_quizzes = valid_c_quizzes(recorded, standard_validity) quizz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True) quizz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False) + # save a bunch of images to investigate what quizzes with a + # certain nb of correct predictions look like + for n in range(len(models) + 1): s = ( "_validated" diff --git a/quizz_machine.py b/quizz_machine.py index 88f2c1c..d591d79 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -398,4 +398,16 @@ class QuizzMachine: device=self.device, ) + c_quizzes = self.reverse_time(c_quizzes) + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes, + ar_mask=ar_mask_solve, + seq_logproba=seq_logproba, + temperature=temperature, + deterministic_synthesis=True, + device=self.device, + ) + return c_quizzes, seq_logproba.mean()