X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=quizz_machine.py;h=0d6d8f57cba918235f951333a64cb1a4c44133d2;hb=d283cd3d46a6323fec4c6a0970ac71e553e4a486;hp=5807b660c1fb2a3e291d41661b01d94e93a52d1f;hpb=a8e608a50b84583ad624cdf69d7b34699557235b;p=culture.git diff --git a/quizz_machine.py b/quizz_machine.py index 5807b66..0d6d8f5 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -333,7 +333,7 @@ class QuizzMachine: ) def compute_correctness( - self, c_quizzes, models_for_validation, both_directions=True + self, c_quizzes, models_for_validation, both_directions=False ): reversed_c_quizzes = self.reverse_time(c_quizzes) @@ -390,7 +390,7 @@ class QuizzMachine: ############################################################### - def generate_quizzes(self, nb, model_for_generation, reverse_cleanup=False): + def generate_quizzes(self, nb, model_for_generation): c_quizzes = torch.empty( nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64 ) @@ -403,10 +403,7 @@ class QuizzMachine: seq_logproba = torch.empty(ar_mask_first.size(0), device=self.device) - if reverse_cleanup: - temperature = 10.0 - else: - temperature = 1.0 + temperature = 10.0 # First, we generate the answer at high temperature @@ -433,7 +430,7 @@ class QuizzMachine: input=c_quizzes, ar_mask=ar_mask_second, seq_logproba=seq_logproba, - temperature=temperature, + temperature=1.0, deterministic_synthesis=True, device=self.device, )