X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=quiz_machine.py;h=34c09a710057c01f5d7b0e4a468c01dffb97cf13;hb=1fe0529f4ac7cacfa8fced2885b7da1e639962a9;hp=cdfba85e6a2abb8d6cd1b14bb3b0f76fe2afad30;hpb=693af34e144cd20d2dde6a508a190d49c1a76c7f;p=culture.git diff --git a/quiz_machine.py b/quiz_machine.py index cdfba85..34c09a7 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -285,6 +285,11 @@ class QuizMachine: predicted_answers, ) + def vocabulary_size(self): + return self.nb_token_values + + ###################################################################### + def batches(self, model, split="train", desc=None): assert split in {"train", "test"} if split == "train": @@ -324,8 +329,7 @@ class QuizMachine: ): yield batch - def vocabulary_size(self): - return self.nb_token_values + ###################################################################### def produce_results( self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000 @@ -400,6 +404,8 @@ class QuizMachine: return main_test_accuracy + ###################################################################### + def renew_w_quizzes(self, model, nb, for_train=True): input = model.train_w_quizzes if for_train else model.test_w_quizzes nb = min(nb, input.size(0)) @@ -408,13 +414,17 @@ class QuizMachine: self.reverse_random_half_in_place(fresh_w_quizzes) input[-nb:] = fresh_w_quizzes.to(self.device) + ###################################################################### + def store_c_quizzes(self, new_c_quizzes, for_train=True): if for_train: self.train_c_quizzes.append(new_c_quizzes) else: self.test_c_quizzes.append(new_c_quizzes) - def logproba_solution(self, models, c_quizzes): + ###################################################################### + + def logproba_of_solutions(self, models, c_quizzes): logproba = c_quizzes.new_zeros(c_quizzes.size(0), len(models)) for model in models: