X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=quizz_machine.py;h=198d279f3d8f9da8d0b602e622d3a8f690b907fd;hb=240870f5535bac35a08c552108d032854a8e2c38;hp=d591d79c53de0f87944eda6d4f7b1527769d5190;hpb=fcb71a73da3a27f81383e3000b9ad1ee8da45926;p=culture.git diff --git a/quizz_machine.py b/quizz_machine.py index d591d79..198d279 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -286,7 +286,9 @@ class QuizzMachine: return torch.cat([c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1) - def compute_correctness(self, c_quizzes, models_for_validation): + def compute_correctness( + self, c_quizzes, models_for_validation, both_direction=True + ): reversed_c_quizzes = self.reverse_time(c_quizzes) ar_mask = self.make_ar_mask(c_quizzes) @@ -313,25 +315,30 @@ class QuizzMachine: correct = (c_quizzes == result).long().min(dim=-1).values - reversed_result = reversed_c_quizzes.clone() + if both_direction: + reversed_result = reversed_c_quizzes.clone() - masked_inplace_autoregression( - model=model, - batch_size=self.batch_size, - input=reversed_result, - ar_mask=ar_mask, - seq_logproba=seq_logproba, - temperature=1.0, - deterministic_synthesis=True, - # progress_bar_desc="solving reversed c_quizzes", - device=self.device, - ) + masked_inplace_autoregression( + model=model, + batch_size=self.batch_size, + input=reversed_result, + ar_mask=ar_mask, + seq_logproba=seq_logproba, + temperature=1.0, + deterministic_synthesis=True, + # progress_bar_desc="solving reversed c_quizzes", + device=self.device, + ) - reversed_correct = ( - (reversed_c_quizzes == reversed_result).long().min(dim=-1).values - ) + reversed_correct = ( + (reversed_c_quizzes == reversed_result).long().min(dim=-1).values + ) + + correct *= reversed_correct + + # endif - nb_correct += correct * reversed_correct + nb_correct += correct return nb_correct