X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=quizz_machine.py;h=632c9ae5875b4db2ea058375cd9f95e5744bf860;hb=30c76210e3ed2704b2a059208f385cb623c1486d;hp=62ae8ce94af2b09b909a8182ad3a4a0b4709a1c1;hpb=cc241681730e50ad149a68c612e3a06f2d4a71be;p=culture.git diff --git a/quizz_machine.py b/quizz_machine.py index 62ae8ce..632c9ae 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -238,10 +238,17 @@ class QuizzMachine: result_dir, "culture_w_quizzes", self.train_w_quizzes[:72], - prediction=True, + show_to_be_predicted=True, ) - def save_quizzes(self, result_dir, filename_prefix, quizzes, prediction=False): + def save_quizzes( + self, + result_dir, + filename_prefix, + quizzes, + show_to_be_predicted=False, + mistakes=None, + ): quizzes = quizzes.clone() forward = quizzes[quizzes[:, 0] == self.token_forward] ib = quizzes[:, 0] == self.token_backward @@ -249,9 +256,17 @@ class QuizzMachine: assert forward.size(0) + backward.size(0) == quizzes.size(0) quizzes[ib] = self.reverse_time(quizzes[ib]) - if prediction: - predicted_prompts = ib - predicted_answers = torch.logical_not(ib) + if show_to_be_predicted: + predicted_prompts = ib.long() + predicted_answers = 1 - predicted_prompts + if mistakes is not None: + # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct + predicted_prompts *= mistakes + predicted_answers *= mistakes + else: + # 0/2 ~ not-to-predict / to predict + predicted_prompts *= 2 + predicted_answers *= 2 else: predicted_prompts = None predicted_answers = None @@ -409,11 +424,14 @@ class QuizzMachine: device=self.device, ) + mistakes = (input == result).flatten(1).long().min(dim=1).values * 2 - 1 + self.save_quizzes( result_dir, f"culture_prediction_{n_epoch:04d}_{model.id:02d}", quizzes=result[:72], - prediction=True, + show_to_be_predicted=True, + mistakes=mistakes[:72], ) return main_test_accuracy