X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=quiz_machine.py;h=88fd9f1dfea42ebfdbd99f290243b66d45c8f639;hb=7b716a85786247b292ee71a635c98a18c66b421d;hp=ae146147ca1f753f327943b1494811081d261be1;hpb=4a63b2b44bc08cb04b236b35a3d36aa242912d48;p=culture.git diff --git a/quiz_machine.py b/quiz_machine.py index ae14614..88fd9f1 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -241,7 +241,7 @@ class QuizMachine: self.train_c_quizzes = [] self.test_c_quizzes = [] - def save_quizzes( + def save_quiz_illustrations( self, result_dir, filename_prefix, @@ -266,7 +266,7 @@ class QuizMachine: predicted_prompts *= 2 predicted_answers *= 2 - self.problem.save_quizzes( + self.problem.save_quiz_illustrations( result_dir, filename_prefix, quizzes[:, 1 : 1 + self.prompt_len], @@ -368,11 +368,7 @@ class QuizMachine: backward_nb_total = correct[n_backward].size(0) self.logger( - f"{log_prefix}_forward_accuracy {n_epoch} model {model.id} nb_correct {forward_nb_correct} / {forward_nb_total} ({forward_nb_correct*100/forward_nb_total} %)" - ) - - self.logger( - f"{log_prefix}_backward_accuracy {n_epoch} model {model.id} nb_correct {backward_nb_correct} / {backward_nb_total} ({backward_nb_correct*100/backward_nb_total} %)" + f"{log_prefix}_accuracy {n_epoch} model {model.id} forward {forward_nb_correct} / {forward_nb_total} backward {backward_nb_correct} / {backward_nb_total}" ) return result, correct @@ -388,7 +384,7 @@ class QuizMachine: ############################## - self.save_quizzes( + self.save_quiz_illustrations( result_dir, f"culture_prediction_{n_epoch:04d}_{model.id:02d}", quizzes=test_result[:72], @@ -420,21 +416,27 @@ class QuizMachine: def logproba_of_solutions(self, models, c_quizzes): logproba = c_quizzes.new_zeros( - c_quizzes.size(0), len(models), device=self.device + c_quizzes.size(0), len(models), device=self.device, dtype=torch.float32 ) for model in models: - for input, l in zip( - c_quizzes.split(self.batch_size), logproba.split(self.batch_size) - ): - input = input.to(self.device) - ar_mask = self.make_ar_mask(input) - output = model(mygpt.BracketedSequence(input)).x - ce = ( - F.cross_entropy(output.transpose(1, 2), input, reduction="none") - * ar_mask - ) - l[:, model.id] = -ce.sum(dim=-1) + with torch.autograd.no_grad(): + t = model.training + model.eval() + + for input, l in zip( + c_quizzes.split(self.batch_size), logproba.split(self.batch_size) + ): + input = input.to(self.device) + ar_mask = self.make_ar_mask(input) + output = model(mygpt.BracketedSequence(input)).x + ce = ( + F.cross_entropy(output.transpose(1, 2), input, reduction="none") + * ar_mask + ) + l[:, model.id] = -ce.sum(dim=-1) + + model.train(t) return logproba.to("cpu")