X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=quizz_machine.py;h=697f27ece4e353fe8a264657013a7fe2f693a630;hb=c9c018e4c19ce92892d7652082fb90719d57441c;hp=d591d79c53de0f87944eda6d4f7b1527769d5190;hpb=fcb71a73da3a27f81383e3000b9ad1ee8da45926;p=culture.git diff --git a/quizz_machine.py b/quizz_machine.py index d591d79..697f27e 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -139,6 +139,7 @@ class QuizzMachine: self.train_w_quizzes = self.problem.generate_token_sequences( nb_train_samples ).to(device) + self.test_w_quizzes = self.problem.generate_token_sequences(nb_test_samples).to( device ) @@ -286,7 +287,9 @@ class QuizzMachine: return torch.cat([c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1) - def compute_correctness(self, c_quizzes, models_for_validation): + def compute_correctness( + self, c_quizzes, models_for_validation, both_directions=True + ): reversed_c_quizzes = self.reverse_time(c_quizzes) ar_mask = self.make_ar_mask(c_quizzes) @@ -313,25 +316,30 @@ class QuizzMachine: correct = (c_quizzes == result).long().min(dim=-1).values - reversed_result = reversed_c_quizzes.clone() + if both_directions: + reversed_result = reversed_c_quizzes.clone() - masked_inplace_autoregression( - model=model, - batch_size=self.batch_size, - input=reversed_result, - ar_mask=ar_mask, - seq_logproba=seq_logproba, - temperature=1.0, - deterministic_synthesis=True, - # progress_bar_desc="solving reversed c_quizzes", - device=self.device, - ) + masked_inplace_autoregression( + model=model, + batch_size=self.batch_size, + input=reversed_result, + ar_mask=ar_mask, + seq_logproba=seq_logproba, + temperature=1.0, + deterministic_synthesis=True, + # progress_bar_desc="solving reversed c_quizzes", + device=self.device, + ) - reversed_correct = ( - (reversed_c_quizzes == reversed_result).long().min(dim=-1).values - ) + reversed_correct = ( + (reversed_c_quizzes == reversed_result).long().min(dim=-1).values + ) + + correct *= reversed_correct + + # endif - nb_correct += correct * reversed_correct + nb_correct += correct return nb_correct