X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=quizz_machine.py;h=470b095ce38ec23d6a034ee3a338d87b7b1e9b52;hb=08d4ba04f038318080fc2815d85843c4873c896f;hp=0d6d8f57cba918235f951333a64cb1a4c44133d2;hpb=d283cd3d46a6323fec4c6a0970ac71e553e4a486;p=culture.git diff --git a/quizz_machine.py b/quizz_machine.py index 0d6d8f5..470b095 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -29,8 +29,6 @@ def one_batch_masked_inplace_autoregression( seq_logproba, temperature=1.0, deterministic_synthesis=False, - forbidden_tokens=None, - forced_biases=None, ): to_generate = (ar_mask.sum(0) > 0).nonzero() @@ -45,12 +43,6 @@ def one_batch_masked_inplace_autoregression( logits = (logits / temperature).log_softmax(dim=-1) - if forbidden_tokens is not None: - logits = logits.masked_fill(forbidden_tokens, float("-inf")) - - if forced_biases is not None: - logits = logits + forced_biases[None, :] - if deterministic_synthesis: t_next = logits.argmax(-1) else: @@ -104,8 +96,6 @@ def masked_inplace_autoregression( seq_logproba=seq_logproba, temperature=temperature, deterministic_synthesis=deterministic_synthesis, - forbidden_tokens=forbidden_tokens, - forced_biases=logit_biases, ) model.train(t) @@ -170,7 +160,6 @@ class QuizzMachine: ) def save_quizzes(self, result_dir, filename_prefix, quizzes, prediction=False): - print(f"DEBUG {quizzes.size()=}") l = (quizzes.size(1) - 1) // 2 forward = (quizzes[:, 0] == self.token_forward).long() backward = (quizzes[:, 0] == self.token_backward).long() @@ -338,7 +327,11 @@ class QuizzMachine: reversed_c_quizzes = self.reverse_time(c_quizzes) ar_mask = self.make_ar_mask(c_quizzes) - seq_logproba = torch.empty(ar_mask.size(0), device=self.device) + seq_logproba = torch.zeros( + c_quizzes.size(0), + max([m.id for m in models_for_validation]) + 1, + device=self.device, + ) # Check how many of models can solve the quizzes in both directions @@ -347,12 +340,14 @@ class QuizzMachine: for model in models_for_validation: result = c_quizzes.clone() + seq_logproba[...] = 0.0 + masked_inplace_autoregression( model=model, batch_size=self.batch_size, input=result, ar_mask=ar_mask, - seq_logproba=seq_logproba, + seq_logproba=seq_logproba[:, model.id], temperature=1.0, deterministic_synthesis=True, # progress_bar_desc="solving c_quizzes", @@ -369,7 +364,7 @@ class QuizzMachine: batch_size=self.batch_size, input=reversed_result, ar_mask=ar_mask, - seq_logproba=seq_logproba, + seq_logproba=seq_logproba[:, model.id], temperature=1.0, deterministic_synthesis=True, # progress_bar_desc="solving reversed c_quizzes", @@ -386,7 +381,7 @@ class QuizzMachine: nb_correct += correct - return nb_correct + return nb_correct, seq_logproba ############################################################### @@ -401,7 +396,7 @@ class QuizzMachine: ar_mask_first[:, 0] = 0 ar_mask_second[:, 0] = 0 - seq_logproba = torch.empty(ar_mask_first.size(0), device=self.device) + seq_logproba = torch.zeros(ar_mask_first.size(0), device=self.device) temperature = 10.0 @@ -420,8 +415,6 @@ class QuizzMachine: device=self.device, ) - ave_seq_logproba = seq_logproba.mean() - # Then, we generate the prompt deterministically masked_inplace_autoregression( @@ -451,4 +444,4 @@ class QuizzMachine: device=self.device, ) - return c_quizzes, seq_logproba.mean() + return c_quizzes