From 240870f5535bac35a08c552108d032854a8e2c38 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 1 Jul 2024 19:25:54 +0300 Subject: [PATCH] Update. --- main.py | 4 +++- quizz_machine.py | 41 ++++++++++++++++++++++++----------------- 2 files changed, 27 insertions(+), 18 deletions(-) diff --git a/main.py b/main.py index eb0ef27..0a7be99 100755 --- a/main.py +++ b/main.py @@ -417,7 +417,9 @@ def create_c_quizzes( sum_logits += c_quizzes.size(0) * ave_seq_logproba sum_nb_c_quizzes += c_quizzes.size(0) - nb_correct = quizz_machine.compute_correctness(c_quizzes, models) + nb_correct = quizz_machine.compute_correctness( + c_quizzes, models, both_direction=True + ) if args.dirty_debug: nb_correct = torch.randint( diff --git a/quizz_machine.py b/quizz_machine.py index d591d79..198d279 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -286,7 +286,9 @@ class QuizzMachine: return torch.cat([c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1) - def compute_correctness(self, c_quizzes, models_for_validation): + def compute_correctness( + self, c_quizzes, models_for_validation, both_direction=True + ): reversed_c_quizzes = self.reverse_time(c_quizzes) ar_mask = self.make_ar_mask(c_quizzes) @@ -313,25 +315,30 @@ class QuizzMachine: correct = (c_quizzes == result).long().min(dim=-1).values - reversed_result = reversed_c_quizzes.clone() + if both_direction: + reversed_result = reversed_c_quizzes.clone() - masked_inplace_autoregression( - model=model, - batch_size=self.batch_size, - input=reversed_result, - ar_mask=ar_mask, - seq_logproba=seq_logproba, - temperature=1.0, - deterministic_synthesis=True, - # progress_bar_desc="solving reversed c_quizzes", - device=self.device, - ) + masked_inplace_autoregression( + model=model, + batch_size=self.batch_size, + input=reversed_result, + ar_mask=ar_mask, + seq_logproba=seq_logproba, + temperature=1.0, + deterministic_synthesis=True, + # progress_bar_desc="solving reversed c_quizzes", + device=self.device, + ) - reversed_correct = ( - (reversed_c_quizzes == reversed_result).long().min(dim=-1).values - ) + reversed_correct = ( + (reversed_c_quizzes == reversed_result).long().min(dim=-1).values + ) + + correct *= reversed_correct + + # endif - nb_correct += correct * reversed_correct + nb_correct += correct return nb_correct -- 2.20.1