X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=622cd567ae5d1d845dfe1b0b186fb1851ae0c80e;hb=17267a244c31be85db250706fead811f20158810;hp=1a6c41572ee8c5e2f5e8ce65b9d8cbfea3834f8a;hpb=9e3211bab93700003ff835e346ef413044147b73;p=culture.git diff --git a/tasks.py b/tasks.py index 1a6c415..622cd56 100755 --- a/tasks.py +++ b/tasks.py @@ -2100,7 +2100,7 @@ import world class World(Task): def save_image(self, input, result_dir, filename, logger): - img = world.sample2img(self.train_input.to("cpu"), self.height, self.width) + img = world.sample2img(input.to("cpu"), self.height, self.width) image_name = os.path.join(result_dir, filename) torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2) logger(f"wrote {image_name}") @@ -2208,7 +2208,8 @@ class World(Task): f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" ) - logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}") + main_test_accuracy = test_nb_correct / test_nb_total + logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}") ############################## @@ -2225,58 +2226,66 @@ class World(Task): device=self.device, ) - self.save_image(result, result_dir, f"world_result_{n_epoch:04d}.png", logger) + self.save_image( + result[:96], result_dir, f"world_result_{n_epoch:04d}.png", logger + ) + + return main_test_accuracy + + def store_new_quizzes(self, new_quizzes, for_train=True): + input = self.train_input if for_train else self.test_input - def store_new_problems(self, new_problems): - nb_current = self.train_input.size(0) - nb_new = new_problems.size(0) + nb_current = input.size(0) + nb_new = new_quizzes.size(0) if nb_new >= nb_current: - self.train_input[...] = new_problems[:nb_current] + input[...] = new_quizzes[:nb_current] else: nb_kept = nb_current - nb_new - self.train_input[:nb_kept] = self.train_input[-nb_kept:].clone() - self.train_input[nb_kept:] = new_problems + input[:nb_kept] = input[-nb_kept:].clone() + input[nb_kept:] = new_quizzes - def create_new_problems(self, n_epoch, result_dir, logger, nb, model, nb_runs): - new_problems = torch.empty( + def create_new_quizzes(self, n_epoch, result_dir, logger, nb, model, nb_runs): + new_quizzes = torch.empty( nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64 ) - ar_mask = torch.full(new_problems.size(), 1, device=self.device) + ar_mask = torch.full(new_quizzes.size(), 1, device=self.device) masked_inplace_autoregression( model, self.batch_size, - new_problems, + new_quizzes, ar_mask, deterministic_synthesis=False, - progress_bar_desc="new problems", + progress_bar_desc="new quizzes", device=self.device, ) - nb_correct = torch.empty(nb, device=self.device, dtype=torch.int64) + input = ( + new_quizzes[:, None, :] + .expand(-1, nb_runs, -1) + .clone() + .reshape(-1, new_quizzes.size(-1)) + ) + result = input.clone() - for n in tqdm.tqdm( - range(new_problems.size(0)), dynamic_ncols=True, desc="checking problems" - ): - result = new_problems[n][None, :].expand(nb_runs, -1).clone() - ar_mask = ( - (torch.arange(result.size(1), device=self.device) > result.size(1) // 2) - .long()[None, :] - .expand_as(result) - ) + ar_mask = ( + (torch.arange(result.size(1), device=self.device) > result.size(1) // 2) + .long()[None, :] + .expand_as(result) + ) - masked_inplace_autoregression( - model, - self.batch_size, - result, - ar_mask, - deterministic_synthesis=False, - progress_bar_desc=None, - device=self.device, - ) + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis=False, + progress_bar_desc=None, + device=self.device, + ) - nb_correct[n] = ( - (new_problems[n][None, :] == result).long().min(dim=1).values.sum() - ) + nb_correct = ( + (input == result).long().min(dim=-1).values.reshape(-1, nb_runs).sum(dim=-1) + ) - return new_problems, nb_correct + return new_quizzes, nb_correct