From 9e3211bab93700003ff835e346ef413044147b73 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 21 Jun 2024 16:45:09 +0200 Subject: [PATCH] Update. --- main.py | 10 +++++++++- tasks.py | 34 ++++++++++++++++++++++++---------- 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/main.py b/main.py index 35f02a3..3acf595 100755 --- a/main.py +++ b/main.py @@ -474,6 +474,7 @@ elif args.task == "world": nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.physical_batch_size, + result_dir=args.result_dir, logger=log_string, device=device, ) @@ -902,7 +903,7 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs): # -------------------------------------------- if n_epoch >= 3: - nb_required = 1000 + nb_required = 100 kept = [] while sum([x.size(0) for x in kept]) < nb_required: @@ -920,6 +921,13 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs): kept.append(to_keep) new_problems = torch.cat(kept, dim=0)[:nb_required] + task.store_new_problems(new_problems) + task.save_image( + new_problems[:96], + args.result_dir, + f"world_new_{n_epoch:04d}.png", + log_string, + ) # -------------------------------------------- diff --git a/tasks.py b/tasks.py index 1b28108..1a6c415 100755 --- a/tasks.py +++ b/tasks.py @@ -2099,11 +2099,18 @@ import world class World(Task): + def save_image(self, input, result_dir, filename, logger): + img = world.sample2img(self.train_input.to("cpu"), self.height, self.width) + image_name = os.path.join(result_dir, filename) + torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2) + logger(f"wrote {image_name}") + def __init__( self, nb_train_samples, nb_test_samples, batch_size, + result_dir=None, logger=None, device=torch.device("cpu"), ): @@ -2141,6 +2148,11 @@ class World(Task): self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + if result_dir is not None: + self.save_image( + self.train_input[:96], result_dir, f"world_train.png", logger + ) + def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input @@ -2200,7 +2212,7 @@ class World(Task): ############################## - input, ar_mask = self.test_input[:64], self.test_ar_mask[:64] + input, ar_mask = self.test_input[:96], self.test_ar_mask[:96] result = input.clone() * (1 - ar_mask) masked_inplace_autoregression( @@ -2213,10 +2225,17 @@ class World(Task): device=self.device, ) - img = world.sample2img(result.to("cpu"), self.height, self.width) - image_name = os.path.join(result_dir, f"world_result_{n_epoch:04d}.png") - torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2) - logger(f"wrote {image_name}") + self.save_image(result, result_dir, f"world_result_{n_epoch:04d}.png", logger) + + def store_new_problems(self, new_problems): + nb_current = self.train_input.size(0) + nb_new = new_problems.size(0) + if nb_new >= nb_current: + self.train_input[...] = new_problems[:nb_current] + else: + nb_kept = nb_current - nb_new + self.train_input[:nb_kept] = self.train_input[-nb_kept:].clone() + self.train_input[nb_kept:] = new_problems def create_new_problems(self, n_epoch, result_dir, logger, nb, model, nb_runs): new_problems = torch.empty( @@ -2234,11 +2253,6 @@ class World(Task): device=self.device, ) - img = world.sample2img(new_problems[:64].to("cpu"), self.height, self.width) - image_name = os.path.join(result_dir, f"world_new_{n_epoch:04d}.png") - torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2) - logger(f"wrote {image_name}") - nb_correct = torch.empty(nb, device=self.device, dtype=torch.int64) for n in tqdm.tqdm( -- 2.20.1