X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=tasks.py;h=49b83ecd3c4a1c6b5763ed41e7f9a97d5c40de1f;hb=d16410119a4e5c1117f7f0fbbe80e3e54f81f28b;hp=1b28108079905addada818ceab335183688434b3;hpb=b9924ef4309a69af2af150dba73125ed5fed0093;p=culture.git diff --git a/tasks.py b/tasks.py index 1b28108..49b83ec 100755 --- a/tasks.py +++ b/tasks.py @@ -2099,11 +2099,18 @@ import world class World(Task): + def save_image(self, input, result_dir, filename, logger): + img = world.sample2img(self.train_input.to("cpu"), self.height, self.width) + image_name = os.path.join(result_dir, filename) + torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2) + logger(f"wrote {image_name}") + def __init__( self, nb_train_samples, nb_test_samples, batch_size, + result_dir=None, logger=None, device=torch.device("cpu"), ): @@ -2141,6 +2148,11 @@ class World(Task): self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + if result_dir is not None: + self.save_image( + self.train_input[:96], result_dir, f"world_train.png", logger + ) + def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input @@ -2196,11 +2208,12 @@ class World(Task): f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" ) - logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}") + main_test_accuracy = test_nb_correct / test_nb_total + logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}") ############################## - input, ar_mask = self.test_input[:64], self.test_ar_mask[:64] + input, ar_mask = self.test_input[:96], self.test_ar_mask[:96] result = input.clone() * (1 - ar_mask) masked_inplace_autoregression( @@ -2213,38 +2226,44 @@ class World(Task): device=self.device, ) - img = world.sample2img(result.to("cpu"), self.height, self.width) - image_name = os.path.join(result_dir, f"world_result_{n_epoch:04d}.png") - torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2) - logger(f"wrote {image_name}") + self.save_image(result, result_dir, f"world_result_{n_epoch:04d}.png", logger) - def create_new_problems(self, n_epoch, result_dir, logger, nb, model, nb_runs): - new_problems = torch.empty( + return main_test_accuracy + + def store_new_quizzes(self, new_quizzes, for_train=True): + input = self.train_input if for_train else self.test_input + + nb_current = input.size(0) + nb_new = new_quizzes.size(0) + if nb_new >= nb_current: + input[...] = new_quizzes[:nb_current] + else: + nb_kept = nb_current - nb_new + input[:nb_kept] = input[-nb_kept:].clone() + input[nb_kept:] = new_quizzes + + def create_new_quizzes(self, n_epoch, result_dir, logger, nb, model, nb_runs): + new_quizzes = torch.empty( nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64 ) - ar_mask = torch.full(new_problems.size(), 1, device=self.device) + ar_mask = torch.full(new_quizzes.size(), 1, device=self.device) masked_inplace_autoregression( model, self.batch_size, - new_problems, + new_quizzes, ar_mask, deterministic_synthesis=False, - progress_bar_desc="new problems", + progress_bar_desc="new quizzes", device=self.device, ) - img = world.sample2img(new_problems[:64].to("cpu"), self.height, self.width) - image_name = os.path.join(result_dir, f"world_new_{n_epoch:04d}.png") - torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2) - logger(f"wrote {image_name}") - nb_correct = torch.empty(nb, device=self.device, dtype=torch.int64) for n in tqdm.tqdm( - range(new_problems.size(0)), dynamic_ncols=True, desc="checking problems" + range(new_quizzes.size(0)), dynamic_ncols=True, desc="checking quizzes" ): - result = new_problems[n][None, :].expand(nb_runs, -1).clone() + result = new_quizzes[n][None, :].expand(nb_runs, -1).clone() ar_mask = ( (torch.arange(result.size(1), device=self.device) > result.size(1) // 2) .long()[None, :] @@ -2262,7 +2281,7 @@ class World(Task): ) nb_correct[n] = ( - (new_problems[n][None, :] == result).long().min(dim=1).values.sum() + (new_quizzes[n][None, :] == result).long().min(dim=1).values.sum() ) - return new_problems, nb_correct + return new_quizzes, nb_correct