X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=quiz_machine.py;h=eab41dc2a52275ccdae06326ef658503e0b1dd4e;hb=a10a4da02cdf8ea87a63e8d74381551e60cf7bd8;hp=8ab5696c8ae381fa419f15e097f926adad3faf93;hpb=a86dff174205c38d8e90d0d89ea399a6afb36359;p=culture.git diff --git a/quiz_machine.py b/quiz_machine.py index 8ab5696..eab41dc 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -241,7 +241,7 @@ class QuizMachine: self.train_c_quizzes = [] self.test_c_quizzes = [] - def save_quizzes( + def save_quiz_illustrations( self, result_dir, filename_prefix, @@ -266,7 +266,7 @@ class QuizMachine: predicted_prompts *= 2 predicted_answers *= 2 - self.problem.save_quizzes( + self.problem.save_quiz_illustrations( result_dir, filename_prefix, quizzes[:, 1 : 1 + self.prompt_len], @@ -368,16 +368,12 @@ class QuizMachine: backward_nb_total = correct[n_backward].size(0) self.logger( - f"{log_prefix}_forward_accuracy {n_epoch} model {model.id} nb_correct {forward_nb_correct} / {forward_nb_total} ({forward_nb_correct*100/forward_nb_total} %)" - ) - - self.logger( - f"{log_prefix}_backward_accuracy {n_epoch} model {model.id} nb_correct {backward_nb_correct} / {backward_nb_total} ({backward_nb_correct*100/backward_nb_total} %)" + f"{log_prefix}_accuracy {n_epoch} model {model.id} forward {forward_nb_correct} / {forward_nb_total} backward {backward_nb_correct} / {backward_nb_total}" ) return result, correct - compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train") + # compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train") test_result, test_correct = compute_accuracy( model.test_w_quizzes[:nmax], log_prefix="test" @@ -388,7 +384,7 @@ class QuizMachine: ############################## - self.save_quizzes( + self.save_quiz_illustrations( result_dir, f"culture_prediction_{n_epoch:04d}_{model.id:02d}", quizzes=test_result[:72], @@ -416,11 +412,17 @@ class QuizMachine: else: self.test_c_quizzes.append(new_c_quizzes.to("cpu")) + def save_c_quizzes(self, filename): + torch.save((self.train_c_quizzes, self.test_c_quizzes), filename) + + def load_c_quizzes(self, filename): + self.train_c_quizzes, self.test_c_quizzes = torch.load(filename) + ###################################################################### def logproba_of_solutions(self, models, c_quizzes): logproba = c_quizzes.new_zeros( - c_quizzes.size(0), len(models), device=self.device + c_quizzes.size(0), len(models), device=self.device, dtype=torch.float32 ) for model in models: