X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=quiz_machine.py;h=c39bf7adbf2fc30ddb155578ebcd7b29b245ec6f;hb=2f87c91cf606a068de1450d198660de7e44cd356;hp=ae146147ca1f753f327943b1494811081d261be1;hpb=4a63b2b44bc08cb04b236b35a3d36aa242912d48;p=culture.git diff --git a/quiz_machine.py b/quiz_machine.py index ae14614..c39bf7a 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -241,7 +241,7 @@ class QuizMachine: self.train_c_quizzes = [] self.test_c_quizzes = [] - def save_quizzes( + def save_quiz_illustrations( self, result_dir, filename_prefix, @@ -266,7 +266,7 @@ class QuizMachine: predicted_prompts *= 2 predicted_answers *= 2 - self.problem.save_quizzes( + self.problem.save_quiz_illustrations( result_dir, filename_prefix, quizzes[:, 1 : 1 + self.prompt_len], @@ -368,11 +368,7 @@ class QuizMachine: backward_nb_total = correct[n_backward].size(0) self.logger( - f"{log_prefix}_forward_accuracy {n_epoch} model {model.id} nb_correct {forward_nb_correct} / {forward_nb_total} ({forward_nb_correct*100/forward_nb_total} %)" - ) - - self.logger( - f"{log_prefix}_backward_accuracy {n_epoch} model {model.id} nb_correct {backward_nb_correct} / {backward_nb_total} ({backward_nb_correct*100/backward_nb_total} %)" + f"{log_prefix}_accuracy {n_epoch} model {model.id} forward {forward_nb_correct} / {forward_nb_total} backward {backward_nb_correct} / {backward_nb_total}" ) return result, correct @@ -388,7 +384,7 @@ class QuizMachine: ############################## - self.save_quizzes( + self.save_quiz_illustrations( result_dir, f"culture_prediction_{n_epoch:04d}_{model.id:02d}", quizzes=test_result[:72], @@ -416,25 +412,37 @@ class QuizMachine: else: self.test_c_quizzes.append(new_c_quizzes.to("cpu")) + def save_c_quizzes(self, filename): + torch.save((self.train_c_quizzes, self.test_c_quizzes), filename) + + def load_c_quizzes(self, filename): + self.train_c_quizzes, self.test_c_quizzes = torch.load(filename) + ###################################################################### def logproba_of_solutions(self, models, c_quizzes): logproba = c_quizzes.new_zeros( - c_quizzes.size(0), len(models), device=self.device + c_quizzes.size(0), len(models), device=self.device, dtype=torch.float32 ) for model in models: - for input, l in zip( - c_quizzes.split(self.batch_size), logproba.split(self.batch_size) - ): - input = input.to(self.device) - ar_mask = self.make_ar_mask(input) - output = model(mygpt.BracketedSequence(input)).x - ce = ( - F.cross_entropy(output.transpose(1, 2), input, reduction="none") - * ar_mask - ) - l[:, model.id] = -ce.sum(dim=-1) + with torch.autograd.no_grad(): + t = model.training + model.eval() + + for input, l in zip( + c_quizzes.split(self.batch_size), logproba.split(self.batch_size) + ): + input = input.to(self.device) + ar_mask = self.make_ar_mask(input) + output = model(mygpt.BracketedSequence(input)).x + ce = ( + F.cross_entropy(output.transpose(1, 2), input, reduction="none") + * ar_mask + ) + l[:, model.id] = -ce.sum(dim=-1) + + model.train(t) return logproba.to("cpu")