From 1fe0529f4ac7cacfa8fced2885b7da1e639962a9 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 10 Jul 2024 16:40:51 +0200 Subject: [PATCH] Update. --- main.py | 2 +- quiz_machine.py | 16 +++++++++++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index d400ab1..d2ac744 100755 --- a/main.py +++ b/main.py @@ -390,7 +390,7 @@ def create_c_quizzes( c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)] if c_quizzes.size(0) > 0: - logproba = quiz_machine.logproba_solution(models, c_quizzes) + logproba = quiz_machine.logproba_of_solutions(models, c_quizzes) for l in logproba: s = " ".join([str(x.item()) for x in l]) logp_file.write(s + "\n") diff --git a/quiz_machine.py b/quiz_machine.py index cdfba85..34c09a7 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -285,6 +285,11 @@ class QuizMachine: predicted_answers, ) + def vocabulary_size(self): + return self.nb_token_values + + ###################################################################### + def batches(self, model, split="train", desc=None): assert split in {"train", "test"} if split == "train": @@ -324,8 +329,7 @@ class QuizMachine: ): yield batch - def vocabulary_size(self): - return self.nb_token_values + ###################################################################### def produce_results( self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000 @@ -400,6 +404,8 @@ class QuizMachine: return main_test_accuracy + ###################################################################### + def renew_w_quizzes(self, model, nb, for_train=True): input = model.train_w_quizzes if for_train else model.test_w_quizzes nb = min(nb, input.size(0)) @@ -408,13 +414,17 @@ class QuizMachine: self.reverse_random_half_in_place(fresh_w_quizzes) input[-nb:] = fresh_w_quizzes.to(self.device) + ###################################################################### + def store_c_quizzes(self, new_c_quizzes, for_train=True): if for_train: self.train_c_quizzes.append(new_c_quizzes) else: self.test_c_quizzes.append(new_c_quizzes) - def logproba_solution(self, models, c_quizzes): + ###################################################################### + + def logproba_of_solutions(self, models, c_quizzes): logproba = c_quizzes.new_zeros(c_quizzes.size(0), len(models)) for model in models: -- 2.20.1