From 17267a244c31be85db250706fead811f20158810 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 21 Jun 2024 19:45:08 +0200 Subject: [PATCH] Update. --- main.py | 15 ++++++++++----- tasks.py | 52 ++++++++++++++++++++++++++++------------------------ 2 files changed, 38 insertions(+), 29 deletions(-) diff --git a/main.py b/main.py index ca0d152..672dab5 100755 --- a/main.py +++ b/main.py @@ -906,6 +906,7 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs): # -------------------------------------------- if test_accuracy >= 0.8: + nb_runs, nb_min_correct, nb_max_correct = 10, 8, 9 nb_for_train, nb_for_test = 1000, 100 kept = [] @@ -914,19 +915,23 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs): n_epoch=n_epoch, result_dir=args.result_dir, logger=log_string, - nb=nb_required, + nb=4 * (nb_for_train + nb_for_test), model=model, - nb_runs=10, + nb_runs=nb_runs, ) - to_keep = new_quizzes[torch.logical_and(nb_correct >= 8, nb_correct < 10)] + to_keep = new_quizzes[ + torch.logical_and( + nb_correct >= nb_min_correct, nb_correct <= nb_max_correct + ) + ] log_string(f"keep {to_keep.size(0)} quizzes") kept.append(to_keep) new_quizzes = torch.cat(kept, dim=0)[: nb_for_train + nb_for_test] - task.store_new_quizzes(new_quizzes[:nb_for_train], train=True) - task.store_new_quizzes(new_quizzes[nb_for_train:], train=False) + task.store_new_quizzes(new_quizzes[:nb_for_train], for_train=True) + task.store_new_quizzes(new_quizzes[nb_for_train:], for_train=False) task.save_image( new_quizzes[:96], diff --git a/tasks.py b/tasks.py index 49b83ec..622cd56 100755 --- a/tasks.py +++ b/tasks.py @@ -2100,7 +2100,7 @@ import world class World(Task): def save_image(self, input, result_dir, filename, logger): - img = world.sample2img(self.train_input.to("cpu"), self.height, self.width) + img = world.sample2img(input.to("cpu"), self.height, self.width) image_name = os.path.join(result_dir, filename) torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2) logger(f"wrote {image_name}") @@ -2226,7 +2226,9 @@ class World(Task): device=self.device, ) - self.save_image(result, result_dir, f"world_result_{n_epoch:04d}.png", logger) + self.save_image( + result[:96], result_dir, f"world_result_{n_epoch:04d}.png", logger + ) return main_test_accuracy @@ -2258,30 +2260,32 @@ class World(Task): device=self.device, ) - nb_correct = torch.empty(nb, device=self.device, dtype=torch.int64) + input = ( + new_quizzes[:, None, :] + .expand(-1, nb_runs, -1) + .clone() + .reshape(-1, new_quizzes.size(-1)) + ) + result = input.clone() - for n in tqdm.tqdm( - range(new_quizzes.size(0)), dynamic_ncols=True, desc="checking quizzes" - ): - result = new_quizzes[n][None, :].expand(nb_runs, -1).clone() - ar_mask = ( - (torch.arange(result.size(1), device=self.device) > result.size(1) // 2) - .long()[None, :] - .expand_as(result) - ) + ar_mask = ( + (torch.arange(result.size(1), device=self.device) > result.size(1) // 2) + .long()[None, :] + .expand_as(result) + ) - masked_inplace_autoregression( - model, - self.batch_size, - result, - ar_mask, - deterministic_synthesis=False, - progress_bar_desc=None, - device=self.device, - ) + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis=False, + progress_bar_desc=None, + device=self.device, + ) - nb_correct[n] = ( - (new_quizzes[n][None, :] == result).long().min(dim=1).values.sum() - ) + nb_correct = ( + (input == result).long().min(dim=-1).values.reshape(-1, nb_runs).sum(dim=-1) + ) return new_quizzes, nb_correct -- 2.20.1