From bc9ed7c97f932ebc81f573c4ecd7207b82a011d7 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 2 Jul 2024 09:55:03 +0300 Subject: [PATCH] Update. --- main.py | 26 ++++++++++---------------- quizz_machine.py | 1 + sky.py | 6 ++++++ 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/main.py b/main.py index 714327d..d412e6c 100755 --- a/main.py +++ b/main.py @@ -79,7 +79,7 @@ parser.add_argument("--dropout", type=float, default=0.1) parser.add_argument("--deterministic_synthesis", action="store_true", default=False) -parser.add_argument("--reverse_cleanup", action="store_true", default=False) +parser.add_argument("--reverse_cleanup", action="store_true", default=True) parser.add_argument("--validation_forward_only", action="store_true", default=False) @@ -364,7 +364,7 @@ def run_tests(model, quizz_machine, deterministic_synthesis): nb_test_samples += input.size(0) - main_test_accuracy = quizz_machine.produce_results( + model.main_test_accuracy = quizz_machine.produce_results( n_epoch=n_epoch, model=model, result_dir=args.result_dir, @@ -375,8 +375,6 @@ def run_tests(model, quizz_machine, deterministic_synthesis): log_string(f"test_perplexity {n_epoch} {test_perplexity}") - model.main_test_accuracy = main_test_accuracy - ###################################################################### @@ -397,8 +395,6 @@ def create_c_quizzes( ): recorded = [] - sum_logits, sum_nb_c_quizzes = 0, 0 - nb_to_create = nb_for_train + nb_for_test # ------------------------------------------------------------ @@ -416,9 +412,6 @@ def create_c_quizzes( reverse_cleanup=args.reverse_cleanup, ) - sum_logits += c_quizzes.size(0) * ave_seq_logproba - sum_nb_c_quizzes += c_quizzes.size(0) - nb_correct = quizz_machine.compute_correctness( c_quizzes, models, both_directions=not args.validation_forward_only ) @@ -456,13 +449,14 @@ def create_c_quizzes( else "" ) - quizz_machine.problem.save_quizzes( - valid_c_quizzes(recorded, criteria=lambda nb_correct: nb_correct == n)[:72], - args.result_dir, - f"culture_c_quiz_{n_epoch:04d}_N{n}{s}", - ) + q = valid_c_quizzes(recorded, criteria=lambda nb_correct: nb_correct == n)[:72] - return sum_logits / sum_nb_c_quizzes + if q.size(0) > 0: + quizz_machine.problem.save_quizzes( + q, + args.result_dir, + f"culture_c_quiz_{n_epoch:04d}_N{n}{s}", + ) ###################################################################### @@ -518,7 +512,7 @@ for n_epoch in range(args.nb_epochs): cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models]) log_string(f"current_test_accuracies {cta}") - # replace a fraction of the w_quizzes with a fresh ones + # replace a fraction of the w_quizzes with fresh ones quizz_machine.renew_w_quizzes(args.nb_train_samples // args.nb_gpts) if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes: diff --git a/quizz_machine.py b/quizz_machine.py index 4e7576e..697f27e 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -139,6 +139,7 @@ class QuizzMachine: self.train_w_quizzes = self.problem.generate_token_sequences( nb_train_samples ).to(device) + self.test_w_quizzes = self.problem.generate_token_sequences(nb_test_samples).to( device ) diff --git a/sky.py b/sky.py index 1164185..4ca4ba7 100755 --- a/sky.py +++ b/sky.py @@ -157,6 +157,12 @@ class Sky(problem.Problem): ###################################################################### + def generate_prompts_and_answers(self, nb): + frame_sequences = self.generate_frame_sequences(nb) + prompts = frame_sequences[:, : frame_sequences.size(0) // 2].flatten(1) + answers = frame_sequences[:, frame_sequences.size(0) // 2 :].flatten(1) + return prompts, answers + def generate_token_sequences(self, nb): frame_sequences = self.generate_frame_sequences(nb) -- 2.20.1