From d16410119a4e5c1117f7f0fbbe80e3e54f81f28b Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 21 Jun 2024 17:28:40 +0200 Subject: [PATCH] Update. --- main.py | 28 +++++++++++++++++----------- tasks.py | 37 +++++++++++++++++++++---------------- 2 files changed, 38 insertions(+), 27 deletions(-) diff --git a/main.py b/main.py index 3acf595..ca0d152 100755 --- a/main.py +++ b/main.py @@ -879,7 +879,7 @@ def run_tests(model, task, deterministic_synthesis): nb_test_samples += input.size(0) - task.produce_results( + main_test_accuracy = task.produce_results( n_epoch=n_epoch, model=model, result_dir=args.result_dir, @@ -888,8 +888,11 @@ def run_tests(model, task, deterministic_synthesis): ) test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples)) + log_string(f"test_perplexity {n_epoch} {test_perplexity}") + return main_test_accuracy + ###################################################################### @@ -898,16 +901,16 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs): one_epoch(model, task, learning_rate) - run_tests(model, task, deterministic_synthesis=False) + test_accuracy = run_tests(model, task, deterministic_synthesis=False) # -------------------------------------------- - if n_epoch >= 3: - nb_required = 100 + if test_accuracy >= 0.8: + nb_for_train, nb_for_test = 1000, 100 kept = [] - while sum([x.size(0) for x in kept]) < nb_required: - new_problems, nb_correct = task.create_new_problems( + while sum([x.size(0) for x in kept]) < nb_for_train + nb_for_test: + new_quizzes, nb_correct = task.create_new_quizzes( n_epoch=n_epoch, result_dir=args.result_dir, logger=log_string, @@ -916,14 +919,17 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs): nb_runs=10, ) - to_keep = new_problems[torch.logical_and(nb_correct >= 8, nb_correct < 10)] - log_string(f"keep {to_keep.size(0)} problems") + to_keep = new_quizzes[torch.logical_and(nb_correct >= 8, nb_correct < 10)] + log_string(f"keep {to_keep.size(0)} quizzes") kept.append(to_keep) - new_problems = torch.cat(kept, dim=0)[:nb_required] - task.store_new_problems(new_problems) + new_quizzes = torch.cat(kept, dim=0)[: nb_for_train + nb_for_test] + + task.store_new_quizzes(new_quizzes[:nb_for_train], train=True) + task.store_new_quizzes(new_quizzes[nb_for_train:], train=False) + task.save_image( - new_problems[:96], + new_quizzes[:96], args.result_dir, f"world_new_{n_epoch:04d}.png", log_string, diff --git a/tasks.py b/tasks.py index 1a6c415..49b83ec 100755 --- a/tasks.py +++ b/tasks.py @@ -2208,7 +2208,8 @@ class World(Task): f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" ) - logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}") + main_test_accuracy = test_nb_correct / test_nb_total + logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}") ############################## @@ -2227,38 +2228,42 @@ class World(Task): self.save_image(result, result_dir, f"world_result_{n_epoch:04d}.png", logger) - def store_new_problems(self, new_problems): - nb_current = self.train_input.size(0) - nb_new = new_problems.size(0) + return main_test_accuracy + + def store_new_quizzes(self, new_quizzes, for_train=True): + input = self.train_input if for_train else self.test_input + + nb_current = input.size(0) + nb_new = new_quizzes.size(0) if nb_new >= nb_current: - self.train_input[...] = new_problems[:nb_current] + input[...] = new_quizzes[:nb_current] else: nb_kept = nb_current - nb_new - self.train_input[:nb_kept] = self.train_input[-nb_kept:].clone() - self.train_input[nb_kept:] = new_problems + input[:nb_kept] = input[-nb_kept:].clone() + input[nb_kept:] = new_quizzes - def create_new_problems(self, n_epoch, result_dir, logger, nb, model, nb_runs): - new_problems = torch.empty( + def create_new_quizzes(self, n_epoch, result_dir, logger, nb, model, nb_runs): + new_quizzes = torch.empty( nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64 ) - ar_mask = torch.full(new_problems.size(), 1, device=self.device) + ar_mask = torch.full(new_quizzes.size(), 1, device=self.device) masked_inplace_autoregression( model, self.batch_size, - new_problems, + new_quizzes, ar_mask, deterministic_synthesis=False, - progress_bar_desc="new problems", + progress_bar_desc="new quizzes", device=self.device, ) nb_correct = torch.empty(nb, device=self.device, dtype=torch.int64) for n in tqdm.tqdm( - range(new_problems.size(0)), dynamic_ncols=True, desc="checking problems" + range(new_quizzes.size(0)), dynamic_ncols=True, desc="checking quizzes" ): - result = new_problems[n][None, :].expand(nb_runs, -1).clone() + result = new_quizzes[n][None, :].expand(nb_runs, -1).clone() ar_mask = ( (torch.arange(result.size(1), device=self.device) > result.size(1) // 2) .long()[None, :] @@ -2276,7 +2281,7 @@ class World(Task): ) nb_correct[n] = ( - (new_problems[n][None, :] == result).long().min(dim=1).values.sum() + (new_quizzes[n][None, :] == result).long().min(dim=1).values.sum() ) - return new_problems, nb_correct + return new_quizzes, nb_correct -- 2.20.1