X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=43d290049cf81c372821c6a13463a0b285d65397;hb=5fff2918fdcc35016195cd209afc864e9cd2ac32;hp=ae4254430653eb236c87b3dfaa31d295f654e05e;hpb=c921b95d0ea5b94a893447fbd4792e5047ba6e99;p=picoclvr.git diff --git a/main.py b/main.py index ae42544..43d2900 100755 --- a/main.py +++ b/main.py @@ -8,7 +8,7 @@ # torch.backends.cuda.matmul.allow_tf23 # torch.autocast(torch.bfloat16) -import math, sys, argparse, time, tqdm, itertools, os +import math, sys, argparse, time, tqdm, os import torch, torchvision from torch import nn @@ -27,7 +27,8 @@ else: ###################################################################### parser = argparse.ArgumentParser( - description="An implementation of GPT with cache to solve a toy geometric reasoning task." + description="An implementation of GPT with cache.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument("--task", type=str, default="picoclvr") @@ -40,7 +41,7 @@ parser.add_argument("--seed", type=int, default=0) parser.add_argument("--nb_epochs", type=int, default=25) -parser.add_argument("--batch_size", type=int, default=25) +parser.add_argument("--batch_size", type=int, default=None) parser.add_argument("--nb_train_samples", type=int, default=250000) @@ -92,6 +93,17 @@ parser.add_argument("--maze_width", type=int, default=21) parser.add_argument("--maze_nb_walls", type=int, default=15) +############################## +# Snake options + +parser.add_argument("--snake_height", type=int, default=6) + +parser.add_argument("--snake_width", type=int, default=8) + +parser.add_argument("--snake_nb_colors", type=int, default=3) + +parser.add_argument("--snake_length", type=int, default=400) + ###################################################################### args = parser.parse_args() @@ -117,6 +129,28 @@ if args.seed >= 0: ###################################################################### +default_args = { + "picoclvr": { + "batch_size": 25, + }, + "mnist": { + "batch_size": 10, + }, + "maze": { + "batch_size": 25, + }, + "snake": { + "batch_size": 20, + }, +} + +if args.task in default_args: + for k, v in default_args[args.task].items(): + if getattr(args, k) is None: + setattr(args, k, v) + +###################################################################### + def log_string(s): t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime()) @@ -138,7 +172,12 @@ for n in vars(args): def masked_inplace_autoregression( model, batch_size, input, ar_mask, forbidden_tokens=None, device=torch.device("cpu") ): - for input, ar_mask in zip(input.split(batch_size), ar_mask.split(batch_size)): + for input, ar_mask in tqdm.tqdm( + zip(input.split(batch_size), ar_mask.split(batch_size)), + dynamic_ncols=True, + desc="autoregression", + total=input.size(0) // batch_size, + ): i = (ar_mask.sum(0) > 0).nonzero() if i.min() > 0: model( @@ -451,13 +490,53 @@ class TaskPicoCLVR(Task): 0, ) - image_name = os.path.join(args.result_dir, f"result_{n_epoch:04d}.png") + image_name = os.path.join(args.result_dir, f"picoclvr_result_{n_epoch:04d}.png") torchvision.utils.save_image( img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=1.0 ) log_string(f"wrote {image_name}") +###################################################################### + + +class TaskMNIST(Task): + def __init__(self, batch_size, device=torch.device("cpu")): + self.device = device + self.batch_size = batch_size + + def batches(self, split="train"): + assert split in {"train", "test"} + data_set = torchvision.datasets.MNIST( + root="./data", train=(split == "train"), download=True + ) + data_input = data_set.data.view(-1, 28 * 28).long() + if args.nb_train_samples is not None: + data_input = data_input[: args.nb_train_samples] + for batch in tqdm.tqdm( + data_input.split(self.batch_size), desc=f"epoch-{split}" + ): + yield batch + + def vocabulary_size(self): + return 256 + + def produce_results(self, n_epoch, model): + results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64) + ar_mask = torch.full_like(results, 1) + masked_inplace_autoregression( + model, self.batch_size, results, ar_mask, device=self.device + ) + image_name = os.path.join(args.result_dir, f"mnist_result_{n_epoch:04d}.png") + torchvision.utils.save_image( + 1 - results.reshape(-1, 1, 28, 28) / 255.0, + image_name, + nrow=16, + pad_value=0.8, + ) + log_string(f"wrote {image_name}") + + ###################################################################### import maze @@ -486,7 +565,7 @@ class TaskMaze(Task): self.width = width self.device = device - train_mazes, train_paths, train_policies = maze.create_maze_data( + train_mazes, train_paths, _ = maze.create_maze_data( nb_train_samples, height=height, width=width, @@ -494,9 +573,8 @@ class TaskMaze(Task): progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"), ) self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device)) - self.train_policies = train_policies.flatten(-2).to(device) - test_mazes, test_paths, test_policies = maze.create_maze_data( + test_mazes, test_paths, _ = maze.create_maze_data( nb_test_samples, height=height, width=width, @@ -504,9 +582,8 @@ class TaskMaze(Task): progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"), ) self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device)) - self.test_policies = test_policies.flatten(-2).to(device) - self.nb_codes = self.train_input.max() + 1 + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} @@ -520,26 +597,6 @@ class TaskMaze(Task): ): yield batch - def policy_batches(self, split="train", nb_to_use=-1, desc=None): - assert split in {"train", "test"} - input = self.train_input if split == "train" else self.test_input - policies = self.train_policies if split == "train" else self.test_policies - input = input[:, : self.height * self.width] - policies = policies * (input != maze.v_wall)[:, None] - - if nb_to_use > 0: - input = input[:nb_to_use] - policies = policies[:nb_to_use] - - if desc is None: - desc = f"epoch-{split}" - for batch in tqdm.tqdm( - zip(input.split(self.batch_size), policies.split(self.batch_size)), - dynamic_ncols=True, - desc=desc, - ): - yield batch - def vocabulary_size(self): return self.nb_codes @@ -589,9 +646,10 @@ class TaskMaze(Task): mazes, paths = self.seq2map(input) _, predicted_paths = self.seq2map(result) - filename = f"result_{n_epoch:04d}.png" + + filename = os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.png") maze.save_image( - os.path.join(args.result_dir, filename), + filename, mazes=mazes, target_paths=paths, predicted_paths=predicted_paths, @@ -605,6 +663,217 @@ class TaskMaze(Task): ###################################################################### +def generate_snake_sequences( + nb, height, width, nb_colors, length, prompt_length, device=torch.device("cpu") +): + worlds = torch.randint(nb_colors, (nb, height, width), device=device) + nb_prior_visits = torch.zeros(nb, height, width, device=device) + + # nb x 2 + snake_position = torch.cat( + ( + torch.randint(height, (nb, 1), device=device), + torch.randint(width, (nb, 1), device=device), + ), + 1, + ) + snake_direction = torch.randint(4, (nb,), device=device) + sequences = torch.empty(nb, 2 * length, device=device, dtype=torch.int64) + sequences_prior_visits = torch.zeros( + nb, 2 * length, device=device, dtype=torch.int64 + ) + i = torch.arange(nb, device=device) # [:,None] + + for l in range(length): + # nb x 3 + snake_next_direction = torch.cat( + ( + (snake_direction[:, None] - 1) % 4, + snake_direction[:, None], + (snake_direction[:, None] + 1) % 4, + ), + 1, + ) + + # nb x 3 + vh = (snake_next_direction + 1) % 2 * (snake_next_direction - 1) + vw = snake_next_direction % 2 * (snake_next_direction - 2) + + # nb x 3 x 2 + snake_next_speed = torch.cat((vh[:, :, None], vw[:, :, None]), 2) + snake_next_position = snake_position[:, None, :] + snake_next_speed + + # nb x 3 + val = torch.logical_and( + torch.logical_and( + snake_next_position[:, :, 0] >= 0, snake_next_position[:, :, 0] < height + ), + torch.logical_and( + snake_next_position[:, :, 1] >= 0, snake_next_position[:, :, 1] < width + ), + ).float() + val = ( + # The multiplicative factors bias toward moving forward + torch.rand_like(val) + * val + * torch.tensor([[1.0, 2.0, 1.0]], device=device) + ) + + # nb + j = val.argmax(1) + snake_direction = snake_next_direction[i, j] + + sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4 + sequences_prior_visits[:, 2 * l] = nb_prior_visits[ + i, snake_position[:, 0], snake_position[:, 1] + ] + if l < prompt_length: + nb_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1 + sequences[:, 2 * l + 1] = snake_direction + + # nb x 2 + snake_position = snake_next_position[i, j] + + return sequences, sequences_prior_visits + + +# generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20) +# exit(0) + + +def snake_solver(input, ar_mask): + for n in range(input.size(0)): + i, j, memory = 0, 0, {} + # print(input[n]) + # print(ar_mask[n]) + for l in range(input.size(1) // 2): + if ar_mask[n, 2 * l] == 1: + if memory.get((i, j)) is None: + input[n, 2 * l] = -1 + else: + input[n, 2 * l] = memory[(i, j)] + else: + # print(f'@3 {memory=}') + if memory.get((i, j)) is None: + memory[(i, j)] = input[n, 2 * l] + else: + assert memory[(i, j)] == input[n, 2 * l], f"n={n} l={l}" + # print(f'@1 {i=} {j=}') + d = input[n, 2 * l + 1].item() + i += (d + 1) % 2 * (d - 1) + j += d % 2 * (d - 2) + # print(f'@2 {i=} {j=}') + + +class TaskSnake(Task): + def __init__( + self, + nb_train_samples, + nb_test_samples, + batch_size, + height, + width, + nb_colors, + length, + prompt_length, + device=torch.device("cpu"), + ): + self.batch_size = batch_size + self.height = height + self.width = width + self.device = device + self.prompt_length = prompt_length + + self.train_input, self.train_prior_visits = generate_snake_sequences( + nb_train_samples, + height, + width, + nb_colors, + length, + prompt_length, + self.device, + ) + self.test_input, self.test_prior_visits = generate_snake_sequences( + nb_test_samples, + height, + width, + nb_colors, + length, + prompt_length, + self.device, + ) + + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + + def batches(self, split="train", nb_to_use=-1, desc=None): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + if nb_to_use > 0: + input = input[:nb_to_use] + if desc is None: + desc = f"epoch-{split}" + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=desc + ): + yield batch + + def vocabulary_size(self): + return self.nb_codes + + def produce_results(self, n_epoch, model): + with torch.autograd.no_grad(): + t = model.training + model.eval() + + def compute_nb_correct(input, prior_visits): + result = input.clone() + i = torch.arange(result.size(1), device=result.device)[None, :] + ar_mask = ( + torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0) + .long() + .expand_as(result) + ) + result *= 1 - ar_mask + + # snake_solver(result,ar_mask) + + masked_inplace_autoregression( + model, self.batch_size, result, ar_mask, device=self.device + ) + + nb_total = ((prior_visits > 0) * ar_mask).sum() + + nb_correct = ( + (result == input).long() * (prior_visits > 0) * ar_mask + ).sum() + + # nb_total = result.size(0) + # nb_correct = ((result - input).abs().sum(1) == 0).sum() + + return nb_total, nb_correct + + # train_nb_total, train_nb_correct = compute_nb_correct( + # self.train_input, self.train_prior_visits + # ) + + # log_string( + # f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%" + # ) + + test_nb_total, test_nb_correct = compute_nb_correct( + self.test_input[:1000], self.test_prior_visits[:1000] + ) + + log_string( + f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" + ) + + model.train(t) + + +###################################################################### + + def picoclvr_pruner_horizontal_green(p): return not ("green" in p and ("left" in p or "right" in p)) @@ -636,6 +905,12 @@ if args.task == "picoclvr": pruner_eval=picoclvr_pruner_eval, ) +elif args.task == "mnist": + task = TaskMNIST( + batch_size=args.batch_size, + device=device, + ) + elif args.task == "maze": task = TaskMaze( nb_train_samples=args.nb_train_samples, @@ -647,6 +922,19 @@ elif args.task == "maze": device=device, ) +elif args.task == "snake": + task = TaskSnake( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + height=args.snake_height, + width=args.snake_width, + nb_colors=args.snake_nb_colors, + length=args.snake_length, + prompt_length=args.snake_length // 2, + device=device, + ) + else: raise ValueError(f"Unknown task {args.task}")