X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=3db87df303a3cb55e8f2755f7b953485596da00f;hb=d363acfa35249faaa1fc6574e50c1c59da277141;hp=ae4254430653eb236c87b3dfaa31d295f654e05e;hpb=c921b95d0ea5b94a893447fbd4792e5047ba6e99;p=picoclvr.git diff --git a/main.py b/main.py index ae42544..3db87df 100755 --- a/main.py +++ b/main.py @@ -451,13 +451,53 @@ class TaskPicoCLVR(Task): 0, ) - image_name = os.path.join(args.result_dir, f"result_{n_epoch:04d}.png") + image_name = os.path.join(args.result_dir, f"picoclvr_result_{n_epoch:04d}.png") torchvision.utils.save_image( img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=1.0 ) log_string(f"wrote {image_name}") +###################################################################### + + +class TaskMNIST(Task): + def __init__(self, batch_size, device=torch.device("cpu")): + self.device = device + self.batch_size = batch_size + + def batches(self, split="train"): + assert split in {"train", "test"} + data_set = torchvision.datasets.MNIST( + root="./data", train=(split == "train"), download=True + ) + data_input = data_set.data.view(-1, 28 * 28).long() + if args.nb_train_samples is not None: + data_input = data_input[: args.nb_train_samples] + for batch in tqdm.tqdm( + data_input.split(self.batch_size), desc=f"epoch-{split}" + ): + yield batch + + def vocabulary_size(self): + return 256 + + def produce_results(self, n_epoch, model): + results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64) + ar_mask = torch.full_like(results, 1) + masked_inplace_autoregression( + model, self.batch_size, results, ar_mask, device=self.device + ) + image_name = os.path.join(args.result_dir, f"result_mnist_{n_epoch:04d}.png") + torchvision.utils.save_image( + 1 - results.reshape(-1, 1, 28, 28) / 255.0, + image_name, + nrow=16, + pad_value=0.8, + ) + log_string(f"wrote {image_name}") + + ###################################################################### import maze @@ -486,7 +526,7 @@ class TaskMaze(Task): self.width = width self.device = device - train_mazes, train_paths, train_policies = maze.create_maze_data( + train_mazes, train_paths, _ = maze.create_maze_data( nb_train_samples, height=height, width=width, @@ -494,9 +534,8 @@ class TaskMaze(Task): progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"), ) self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device)) - self.train_policies = train_policies.flatten(-2).to(device) - test_mazes, test_paths, test_policies = maze.create_maze_data( + test_mazes, test_paths, _ = maze.create_maze_data( nb_test_samples, height=height, width=width, @@ -504,9 +543,8 @@ class TaskMaze(Task): progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"), ) self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device)) - self.test_policies = test_policies.flatten(-2).to(device) - self.nb_codes = self.train_input.max() + 1 + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} @@ -520,26 +558,6 @@ class TaskMaze(Task): ): yield batch - def policy_batches(self, split="train", nb_to_use=-1, desc=None): - assert split in {"train", "test"} - input = self.train_input if split == "train" else self.test_input - policies = self.train_policies if split == "train" else self.test_policies - input = input[:, : self.height * self.width] - policies = policies * (input != maze.v_wall)[:, None] - - if nb_to_use > 0: - input = input[:nb_to_use] - policies = policies[:nb_to_use] - - if desc is None: - desc = f"epoch-{split}" - for batch in tqdm.tqdm( - zip(input.split(self.batch_size), policies.split(self.batch_size)), - dynamic_ncols=True, - desc=desc, - ): - yield batch - def vocabulary_size(self): return self.nb_codes @@ -589,9 +607,10 @@ class TaskMaze(Task): mazes, paths = self.seq2map(input) _, predicted_paths = self.seq2map(result) - filename = f"result_{n_epoch:04d}.png" + + filename = os.path.join(args.result_dir, f"result_{n_epoch:04d}.png") maze.save_image( - os.path.join(args.result_dir, filename), + filename, mazes=mazes, target_paths=paths, predicted_paths=predicted_paths, @@ -602,6 +621,42 @@ class TaskMaze(Task): model.train(t) +###################################################################### + +class TaskSnake(Task): + def __init__( + self, + nb_train_samples, + nb_test_samples, + batch_size, + height, + width, + nb_walls, + device=torch.device("cpu"), + ): + self.batch_size = batch_size + self.height = height + self.width = width + self.device = device + + # self.train_input = + # self.test_input = + + self.nb_codes = max(self.train_input.max(), self.train_input.max()) + 1 + + def batches(self, split="train", nb_to_use=-1, desc=None): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + if nb_to_use > 0: + input = input[:nb_to_use] + if desc is None: + desc = f"epoch-{split}" + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=desc + ): + yield batch + + ###################################################################### @@ -636,6 +691,12 @@ if args.task == "picoclvr": pruner_eval=picoclvr_pruner_eval, ) +elif args.task == "mnist": + task = TaskMNIST( + batch_size=args.batch_size, + device=device, + ) + elif args.task == "maze": task = TaskMaze( nb_train_samples=args.nb_train_samples,