From 08d4ba04f038318080fc2815d85843c4873c896f Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 2 Jul 2024 17:27:57 +0300 Subject: [PATCH] Update. --- main.py | 54 +++++++++++++++++++++++++++--------------------- quizz_machine.py | 31 +++++++++++---------------- 2 files changed, 43 insertions(+), 42 deletions(-) diff --git a/main.py b/main.py index d63398c..7b8b642 100755 --- a/main.py +++ b/main.py @@ -362,6 +362,10 @@ def run_tests(model, quizz_machine, deterministic_synthesis): nb_test_samples += input.size(0) + test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples)) + + log_string(f"test_perplexity {n_epoch} {test_perplexity}") + model.main_test_accuracy = quizz_machine.produce_results( n_epoch=n_epoch, model=model, @@ -369,10 +373,6 @@ def run_tests(model, quizz_machine, deterministic_synthesis): deterministic_synthesis=deterministic_synthesis, ) - test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples)) - - log_string(f"test_perplexity {n_epoch} {test_perplexity}") - ###################################################################### @@ -401,33 +401,41 @@ def create_c_quizzes( nb_correct >= args.min_to_validate, nb_correct <= args.max_to_validate ) - while valid_c_quizzes(recorded, standard_validity).size(0) < nb_to_create: - model_for_generation = models[torch.randint(len(models), (1,))] + file_name = os.path.join(args.result_dir, f"culture_c_quiz_{n_epoch:04d}_logp.dat") + with open(file_name, "w") as logp_file: + while valid_c_quizzes(recorded, standard_validity).size(0) < nb_to_create: + # Select a model at random to generate the new quizzes - c_quizzes, ave_seq_logproba = quizz_machine.generate_quizzes( - nb_to_create, - model_for_generation=model_for_generation, - ) + model_for_generation = models[torch.randint(len(models), (1,))] - nb_correct = quizz_machine.compute_correctness( - c_quizzes, models, both_directions=args.both_directions - ) + c_quizzes = quizz_machine.generate_quizzes( + nb_to_create, + model_for_generation=model_for_generation, + ) - if args.dirty_debug: - nb_correct = torch.randint( - len(models) + 1, nb_correct.size(), device=c_quizzes.device + nb_correct, seq_logproba = quizz_machine.compute_correctness( + c_quizzes, models, both_directions=args.both_directions ) - recorded.append((c_quizzes, nb_correct)) + for n, l in zip(nb_correct, seq_logproba): + s = " ".join([str(x.item()) for x in l]) + logp_file.write(f"{n} {s}\n") - nv = F.one_hot(nb_correct, num_classes=len(models) + 1).sum(0) - nv = " ".join([str(x.item()) for x in nv]) + if args.dirty_debug: + nb_correct = torch.randint( + len(models) + 1, nb_correct.size(), device=c_quizzes.device + ) - nb_validated = valid_c_quizzes(recorded, standard_validity).size(0) + recorded.append((c_quizzes, nb_correct)) - log_string( - f"keep c_quizzes model {model_for_generation.id} kept {nv} nb_accumulated {nb_validated} / {nb_to_create}" - ) + nv = F.one_hot(nb_correct, num_classes=len(models) + 1).sum(0) + nv = " ".join([str(x.item()) for x in nv]) + + nb_validated = valid_c_quizzes(recorded, standard_validity).size(0) + + log_string( + f"keep c_quizzes model {model_for_generation.id} kept {nv} nb_accumulated {nb_validated} / {nb_to_create}" + ) # store the new c_quizzes which have been validated diff --git a/quizz_machine.py b/quizz_machine.py index 0d6d8f5..470b095 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -29,8 +29,6 @@ def one_batch_masked_inplace_autoregression( seq_logproba, temperature=1.0, deterministic_synthesis=False, - forbidden_tokens=None, - forced_biases=None, ): to_generate = (ar_mask.sum(0) > 0).nonzero() @@ -45,12 +43,6 @@ def one_batch_masked_inplace_autoregression( logits = (logits / temperature).log_softmax(dim=-1) - if forbidden_tokens is not None: - logits = logits.masked_fill(forbidden_tokens, float("-inf")) - - if forced_biases is not None: - logits = logits + forced_biases[None, :] - if deterministic_synthesis: t_next = logits.argmax(-1) else: @@ -104,8 +96,6 @@ def masked_inplace_autoregression( seq_logproba=seq_logproba, temperature=temperature, deterministic_synthesis=deterministic_synthesis, - forbidden_tokens=forbidden_tokens, - forced_biases=logit_biases, ) model.train(t) @@ -170,7 +160,6 @@ class QuizzMachine: ) def save_quizzes(self, result_dir, filename_prefix, quizzes, prediction=False): - print(f"DEBUG {quizzes.size()=}") l = (quizzes.size(1) - 1) // 2 forward = (quizzes[:, 0] == self.token_forward).long() backward = (quizzes[:, 0] == self.token_backward).long() @@ -338,7 +327,11 @@ class QuizzMachine: reversed_c_quizzes = self.reverse_time(c_quizzes) ar_mask = self.make_ar_mask(c_quizzes) - seq_logproba = torch.empty(ar_mask.size(0), device=self.device) + seq_logproba = torch.zeros( + c_quizzes.size(0), + max([m.id for m in models_for_validation]) + 1, + device=self.device, + ) # Check how many of models can solve the quizzes in both directions @@ -347,12 +340,14 @@ class QuizzMachine: for model in models_for_validation: result = c_quizzes.clone() + seq_logproba[...] = 0.0 + masked_inplace_autoregression( model=model, batch_size=self.batch_size, input=result, ar_mask=ar_mask, - seq_logproba=seq_logproba, + seq_logproba=seq_logproba[:, model.id], temperature=1.0, deterministic_synthesis=True, # progress_bar_desc="solving c_quizzes", @@ -369,7 +364,7 @@ class QuizzMachine: batch_size=self.batch_size, input=reversed_result, ar_mask=ar_mask, - seq_logproba=seq_logproba, + seq_logproba=seq_logproba[:, model.id], temperature=1.0, deterministic_synthesis=True, # progress_bar_desc="solving reversed c_quizzes", @@ -386,7 +381,7 @@ class QuizzMachine: nb_correct += correct - return nb_correct + return nb_correct, seq_logproba ############################################################### @@ -401,7 +396,7 @@ class QuizzMachine: ar_mask_first[:, 0] = 0 ar_mask_second[:, 0] = 0 - seq_logproba = torch.empty(ar_mask_first.size(0), device=self.device) + seq_logproba = torch.zeros(ar_mask_first.size(0), device=self.device) temperature = 10.0 @@ -420,8 +415,6 @@ class QuizzMachine: device=self.device, ) - ave_seq_logproba = seq_logproba.mean() - # Then, we generate the prompt deterministically masked_inplace_autoregression( @@ -451,4 +444,4 @@ class QuizzMachine: device=self.device, ) - return c_quizzes, seq_logproba.mean() + return c_quizzes -- 2.20.1