description="An implementation of GPT with cache to solve a toy geometric reasoning task."
)
+parser.add_argument("--task", type=str, default="picoclvr")
+
parser.add_argument("--log_filename", type=str, default="train.log")
parser.add_argument("--result_dir", type=str, default="results_default")
##############################
# picoclvr options
-parser.add_argument("--nb_colors", type=int, default=5)
+parser.add_argument("--picoclvr_nb_colors", type=int, default=5)
+
+parser.add_argument("--picoclvr_height", type=int, default=12)
+
+parser.add_argument("--picoclvr_width", type=int, default=16)
+
+parser.add_argument("--picocvlr_prune_properties", type=str, default="none")
+
+##############################
+# Maze options
-parser.add_argument("--height", type=int, default=12)
+parser.add_argument("--maze_height", type=int, default=13)
-parser.add_argument("--width", type=int, default=16)
+parser.add_argument("--maze_width", type=int, default=21)
-parser.add_argument("--prune_properties", type=str, default="none")
+parser.add_argument("--maze_nb_walls", type=int, default=15)
######################################################################
args = parser.parse_args()
-assert args.prune_properties in {"none", "train+eval", "eval"}
+assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
try:
os.mkdir(args.result_dir)
"rng_state": list(torch.get_rng_state()),
}
- log_string(f"generating {nb_train_samples+nb_test_samples} samples (can take some time)")
- self.train_descr = generate_descr(nb_train_samples, "train", pruner=self.pruner_train)
+ log_string(
+ f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
+ )
+ self.train_descr = generate_descr(
+ nb_train_samples, "train", pruner=self.pruner_train
+ )
self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)
# Build the tokenizer
######################################################################
-log_string(f"device {device}")
+import maze
+
+class TaskMaze(Task):
+ def map2seq(self, *m):
+ return torch.cat([x.flatten(1) for x in m], 1)
-def pruner_horizontal_green(p):
+ def seq2map(self, s):
+ s = s.reshape(s.size(0), -1, self.height, self.width)
+ return (s[:, k] for k in range(s.size(1)))
+
+ def __init__(
+ self,
+ nb_train_samples,
+ nb_test_samples,
+ batch_size,
+ height,
+ width,
+ nb_walls,
+ device=torch.device("cpu"),
+ ):
+ self.batch_size = batch_size
+ self.height = height
+ self.width = width
+ self.device = device
+
+ train_mazes, train_paths, train_policies = maze.create_maze_data(
+ nb_train_samples,
+ height=height,
+ width=width,
+ nb_walls=nb_walls,
+ progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
+ )
+ self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
+ self.train_policies = train_policies.flatten(-2).to(device)
+
+ test_mazes, test_paths, test_policies = maze.create_maze_data(
+ nb_test_samples,
+ height=height,
+ width=width,
+ nb_walls=nb_walls,
+ progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
+ )
+ self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
+ self.test_policies = test_policies.flatten(-2).to(device)
+
+ self.nb_codes = self.train_input.max() + 1
+
+ def batches(self, split="train", nb_to_use=-1, desc=None):
+ assert split in {"train", "test"}
+ input = self.train_input if split == "train" else self.test_input
+ if nb_to_use > 0:
+ input = input[:nb_to_use]
+ if desc is None:
+ desc = f"epoch-{split}"
+ for batch in tqdm.tqdm(
+ input.split(self.batch_size), dynamic_ncols=True, desc=desc
+ ):
+ yield batch
+
+ def policy_batches(self, split="train", nb_to_use=-1, desc=None):
+ assert split in {"train", "test"}
+ input = self.train_input if split == "train" else self.test_input
+ policies = self.train_policies if split == "train" else self.test_policies
+ input = input[:, : self.height * self.width]
+ policies = policies * (input != maze.v_wall)[:, None]
+
+ if nb_to_use > 0:
+ input = input[:nb_to_use]
+ policies = policies[:nb_to_use]
+
+ if desc is None:
+ desc = f"epoch-{split}"
+ for batch in tqdm.tqdm(
+ zip(input.split(self.batch_size), policies.split(self.batch_size)),
+ dynamic_ncols=True,
+ desc=desc,
+ ):
+ yield batch
+
+ def vocabulary_size(self):
+ return self.nb_codes
+
+ def compute_error(self, model, split="train", nb_to_use=-1):
+ nb_total, nb_correct = 0, 0
+ for input in task.batches(split, nb_to_use):
+ result = input.clone()
+ ar_mask = result.new_zeros(result.size())
+ ar_mask[:, self.height * self.width :] = 1
+ result *= 1 - ar_mask
+ masked_inplace_autoregression(
+ model, self.batch_size, result, ar_mask, device=self.device
+ )
+ mazes, paths = self.seq2map(result)
+ nb_correct += maze.path_correctness(mazes, paths).long().sum()
+ nb_total += mazes.size(0)
+
+ return nb_total, nb_correct
+
+ def produce_results(self, n_epoch, model):
+ with torch.autograd.no_grad():
+ t = model.training
+ model.eval()
+
+ train_nb_total, train_nb_correct = self.compute_error(
+ model, "train", nb_to_use=1000
+ )
+ log_string(
+ f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
+ )
+
+ test_nb_total, test_nb_correct = self.compute_error(
+ model, "test", nb_to_use=1000
+ )
+ log_string(
+ f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+ )
+
+ input = self.test_input[:48]
+ result = input.clone()
+ ar_mask = result.new_zeros(result.size())
+ ar_mask[:, self.height * self.width :] = 1
+ result *= 1 - ar_mask
+ masked_inplace_autoregression(
+ model, self.batch_size, result, ar_mask, device=self.device
+ )
+
+ mazes, paths = self.seq2map(input)
+ _, predicted_paths = self.seq2map(result)
+ filename = f"result_{n_epoch:04d}.png"
+ maze.save_image(
+ os.path.join(args.result_dir, filename),
+ mazes=mazes,
+ target_paths=paths,
+ predicted_paths=predicted_paths,
+ path_correct=maze.path_correctness(mazes, predicted_paths),
+ )
+ log_string(f"wrote {filename}")
+
+ model.train(t)
+
+
+######################################################################
+
+
+def picoclvr_pruner_horizontal_green(p):
return not ("green" in p and ("left" in p or "right" in p))
-task = TaskPicoCLVR(
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
- height=args.height,
- width=args.width,
- nb_colors=args.nb_colors,
- device=device,
- pruner_train=pruner_horizontal_green
- if args.prune_properties in {"train+eval"}
- else None,
- pruner_eval=(lambda p: not pruner_horizontal_green(p))
- if args.prune_properties in {"train+eval", "eval"}
- else None,
+picoclvr_pruner_train = (
+ picoclvr_pruner_horizontal_green
+ if args.picocvlr_prune_properties in {"train+eval"}
+ else None
+)
+
+picoclvr_pruner_eval = (
+ (lambda p: not picoclvr_pruner_horizontal_green(p))
+ if args.picocvlr_prune_properties in {"train+eval", "eval"}
+ else None
)
+######################################################################
+
+if args.task == "picoclvr":
+ task = TaskPicoCLVR(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ height=args.picoclvr_height,
+ width=args.picoclvr_width,
+ nb_colors=args.picoclvr_nb_colors,
+ device=device,
+ pruner_train=picoclvr_pruner_train,
+ pruner_eval=picoclvr_pruner_eval,
+ )
+
+elif args.task == "maze":
+ task = TaskMaze(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ height=args.maze_height,
+ width=args.maze_width,
+ nb_walls=args.maze_nb_walls,
+ device=device,
+ )
+
+else:
+ raise ValueError(f"Unknown task {args.task}")
+
+######################################################################
+
+log_string(f"device {device}")
+
vocabulary_size = task.vocabulary_size()
log_string(f"vocabulary_size {vocabulary_size}")
--- /dev/null
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import torch, torchvision
+
+######################################################################
+
+v_empty, v_wall, v_start, v_goal, v_path = 0, 1, 2, 3, 4
+
+
+def create_maze(h=11, w=17, nb_walls=8):
+ a, k = 0, 0
+
+ while k < nb_walls:
+ while True:
+ if a == 0:
+ m = torch.zeros(h, w, dtype=torch.int64)
+ m[0, :] = 1
+ m[-1, :] = 1
+ m[:, 0] = 1
+ m[:, -1] = 1
+
+ r = torch.rand(4)
+
+ if r[0] <= 0.5:
+ i1, i2, j = (
+ int((r[1] * h).item()),
+ int((r[2] * h).item()),
+ int((r[3] * w).item()),
+ )
+ i1, i2, j = i1 - i1 % 2, i2 - i2 % 2, j - j % 2
+ i1, i2 = min(i1, i2), max(i1, i2)
+ if i2 - i1 > 1 and i2 - i1 <= h / 2 and m[i1 : i2 + 1, j].sum() <= 1:
+ m[i1 : i2 + 1, j] = 1
+ break
+ else:
+ i, j1, j2 = (
+ int((r[1] * h).item()),
+ int((r[2] * w).item()),
+ int((r[3] * w).item()),
+ )
+ i, j1, j2 = i - i % 2, j1 - j1 % 2, j2 - j2 % 2
+ j1, j2 = min(j1, j2), max(j1, j2)
+ if j2 - j1 > 1 and j2 - j1 <= w / 2 and m[i, j1 : j2 + 1].sum() <= 1:
+ m[i, j1 : j2 + 1] = 1
+ break
+ a += 1
+
+ if a > 10 * nb_walls:
+ a, k = 0, 0
+
+ k += 1
+
+ return m
+
+
+######################################################################
+
+
+def compute_distance(walls, goal_i, goal_j):
+ max_length = walls.numel()
+ dist = torch.full_like(walls, max_length)
+
+ dist[goal_i, goal_j] = 0
+ pred_dist = torch.empty_like(dist)
+
+ while True:
+ pred_dist.copy_(dist)
+ d = (
+ torch.cat(
+ (
+ dist[None, 1:-1, 0:-2],
+ dist[None, 2:, 1:-1],
+ dist[None, 1:-1, 2:],
+ dist[None, 0:-2, 1:-1],
+ ),
+ 0,
+ ).min(dim=0)[0]
+ + 1
+ )
+
+ dist[1:-1, 1:-1] = torch.min(dist[1:-1, 1:-1], d)
+ dist = walls * max_length + (1 - walls) * dist
+
+ if dist.equal(pred_dist):
+ return dist * (1 - walls)
+
+
+######################################################################
+
+
+def compute_policy(walls, goal_i, goal_j):
+ distance = compute_distance(walls, goal_i, goal_j)
+ distance = distance + walls.numel() * walls
+
+ value = distance.new_full((4,) + distance.size(), walls.numel())
+ value[0, :, 1:] = distance[:, :-1] # <
+ value[1, :, :-1] = distance[:, 1:] # >
+ value[2, 1:, :] = distance[:-1, :] # ^
+ value[3, :-1, :] = distance[1:, :] # v
+
+ proba = (value.min(dim=0)[0][None] == value).float()
+ proba = proba / proba.sum(dim=0)[None]
+ proba = proba * (1 - walls) + walls.float() / 4
+
+ return proba
+
+
+def stationary_densities(mazes, policies):
+ policies = policies * (mazes != v_goal)[:, None]
+ start = (mazes == v_start).nonzero(as_tuple=True)
+ probas = mazes.new_zeros(mazes.size(), dtype=torch.float32)
+ pred_probas = probas.clone()
+ probas[start] = 1.0
+
+ while not pred_probas.equal(probas):
+ pred_probas.copy_(probas)
+ probas.zero_()
+ probas[:, 1:, :] += pred_probas[:, :-1, :] * policies[:, 3, :-1, :]
+ probas[:, :-1, :] += pred_probas[:, 1:, :] * policies[:, 2, 1:, :]
+ probas[:, :, 1:] += pred_probas[:, :, :-1] * policies[:, 1, :, :-1]
+ probas[:, :, :-1] += pred_probas[:, :, 1:] * policies[:, 0, :, 1:]
+ probas[start] = 1.0
+
+ return probas
+
+
+######################################################################
+
+
+def mark_path(walls, i, j, goal_i, goal_j, policy):
+ action = torch.distributions.categorical.Categorical(
+ policy.permute(1, 2, 0)
+ ).sample()
+ n, nmax = 0, walls.numel()
+ while i != goal_i or j != goal_j:
+ di, dj = [(0, -1), (0, 1), (-1, 0), (1, 0)][action[i, j]]
+ i, j = i + di, j + dj
+ assert walls[i, j] == 0
+ walls[i, j] = v_path
+ n += 1
+ assert n < nmax
+
+
+def path_correctness(mazes, paths):
+ still_ok = (mazes - (paths * (paths < 4))).view(mazes.size(0), -1).abs().sum(1) == 0
+ reached = still_ok.new_zeros(still_ok.size())
+ current, pred_current = paths.clone(), paths.new_zeros(paths.size())
+ goal = (mazes == v_goal).long()
+ while not pred_current.equal(current):
+ pred_current.copy_(current)
+ u = (current == v_start).long()
+ possible_next = (
+ u[:, 2:, 1:-1] + u[:, 0:-2, 1:-1] + u[:, 1:-1, 2:] + u[:, 1:-1, 0:-2] > 0
+ ).long()
+ u = u[:, 1:-1, 1:-1]
+ reached += ((goal[:, 1:-1, 1:-1] * possible_next).sum((1, 2)) == 1) * (
+ (current == v_path).sum((1, 2)) == 0
+ )
+ current[:, 1:-1, 1:-1] = (1 - u) * current[:, 1:-1, 1:-1] + (
+ v_start - v_path
+ ) * (possible_next * (current[:, 1:-1, 1:-1] == v_path))
+ still_ok *= (current == v_start).sum((1, 2)) <= 1
+
+ return still_ok * reached
+
+
+######################################################################
+
+
+def create_maze_data(
+ nb, height=11, width=17, nb_walls=8, dist_min=10, progress_bar=lambda x: x
+):
+ mazes = torch.empty(nb, height, width, dtype=torch.int64)
+ paths = torch.empty(nb, height, width, dtype=torch.int64)
+ policies = torch.empty(nb, 4, height, width)
+
+ for n in progress_bar(range(nb)):
+ maze = create_maze(height, width, nb_walls)
+ i = (maze == v_empty).nonzero()
+ while True:
+ start, goal = i[torch.randperm(i.size(0))[:2]]
+ if (start - goal).abs().sum() >= dist_min:
+ break
+ start_i, start_j, goal_i, goal_j = start[0], start[1], goal[0], goal[1]
+
+ policy = compute_policy(maze, goal_i, goal_j)
+ path = maze.clone()
+ mark_path(path, start_i, start_j, goal_i, goal_j, policy)
+ maze[start_i, start_j] = v_start
+ maze[goal_i, goal_j] = v_goal
+ path[start_i, start_j] = v_start
+ path[goal_i, goal_j] = v_goal
+
+ mazes[n] = maze
+ paths[n] = path
+ policies[n] = policy
+
+ return mazes, paths, policies
+
+
+######################################################################
+
+
+def save_image(
+ name,
+ mazes,
+ target_paths=None,
+ predicted_paths=None,
+ score_paths=None,
+ score_truth=None,
+ path_correct=None,
+):
+ colors = torch.tensor(
+ [
+ [255, 255, 255], # empty
+ [0, 0, 0], # wall
+ [0, 255, 0], # start
+ [127, 127, 255], # goal
+ [255, 0, 0], # path
+ ]
+ )
+
+ mazes = mazes.cpu()
+
+ c_mazes = (
+ colors[mazes.reshape(-1)].reshape(mazes.size() + (-1,)).permute(0, 3, 1, 2)
+ )
+
+ if score_truth is not None:
+ score_truth = score_truth.cpu()
+ c_score_truth = score_truth.unsqueeze(1).expand(-1, 3, -1, -1)
+ c_score_truth = (
+ c_score_truth * colors[4].reshape(1, 3, 1, 1)
+ + (1 - c_score_truth) * colors[0].reshape(1, 3, 1, 1)
+ ).long()
+ c_mazes = (mazes.unsqueeze(1) != v_empty) * c_mazes + (
+ mazes.unsqueeze(1) == v_empty
+ ) * c_score_truth
+
+ imgs = c_mazes.unsqueeze(1)
+
+ if target_paths is not None:
+ target_paths = target_paths.cpu()
+
+ c_target_paths = (
+ colors[target_paths.reshape(-1)]
+ .reshape(target_paths.size() + (-1,))
+ .permute(0, 3, 1, 2)
+ )
+
+ imgs = torch.cat((imgs, c_target_paths.unsqueeze(1)), 1)
+
+ if predicted_paths is not None:
+ predicted_paths = predicted_paths.cpu()
+ c_predicted_paths = (
+ colors[predicted_paths.reshape(-1)]
+ .reshape(predicted_paths.size() + (-1,))
+ .permute(0, 3, 1, 2)
+ )
+ imgs = torch.cat((imgs, c_predicted_paths.unsqueeze(1)), 1)
+
+ if score_paths is not None:
+ score_paths = score_paths.cpu()
+ c_score_paths = score_paths.unsqueeze(1).expand(-1, 3, -1, -1)
+ c_score_paths = (
+ c_score_paths * colors[4].reshape(1, 3, 1, 1)
+ + (1 - c_score_paths) * colors[0].reshape(1, 3, 1, 1)
+ ).long()
+ c_score_paths = c_score_paths * (mazes.unsqueeze(1) == v_empty) + c_mazes * (
+ mazes.unsqueeze(1) != v_empty
+ )
+ imgs = torch.cat((imgs, c_score_paths.unsqueeze(1)), 1)
+
+ # NxKxCxHxW
+ if path_correct is None:
+ path_correct = torch.zeros(imgs.size(0)) <= 1
+ path_correct = path_correct.cpu().long().view(-1, 1, 1, 1)
+ img = torch.tensor([224, 224, 224]).view(1, -1, 1, 1) * path_correct + torch.tensor(
+ [255, 0, 0]
+ ).view(1, -1, 1, 1) * (1 - path_correct)
+ img = img.expand(
+ -1, -1, imgs.size(3) + 2, 1 + imgs.size(1) * (1 + imgs.size(4))
+ ).clone()
+ for k in range(imgs.size(1)):
+ img[
+ :,
+ :,
+ 1 : 1 + imgs.size(3),
+ 1 + k * (1 + imgs.size(4)) : 1 + k * (1 + imgs.size(4)) + imgs.size(4),
+ ] = imgs[:, k]
+
+ img = img.float() / 255.0
+
+ torchvision.utils.save_image(img, name, nrow=4, padding=1, pad_value=224.0 / 256)
+
+
+######################################################################
+
+if __name__ == "__main__":
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ mazes, paths = create_maze_data(8)
+ mazes, paths = mazes.to(device), paths.to(device)
+ save_image("test.png", mazes, paths, paths)
+ print(path_correctness(mazes, paths))
+
+######################################################################