--- /dev/null
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+# torch.backends.cuda.matmul.allow_tf23
+# torch.autocast(torch.bfloat16)
+
+import math, sys, argparse, time, tqdm, itertools, os
+
+import torch, torchvision
+from torch import nn
+from torch.nn import functional as F
+
+import mygpt, tensorstack
+
+######################################################################
+
+if torch.cuda.is_available():
+ device = torch.device("cuda")
+ torch.backends.cuda.matmul.allow_tf32 = True
+else:
+ device = torch.device("cpu")
+
+######################################################################
+
+parser = argparse.ArgumentParser(
+ description="An implementation of GPT with cache to solve a toy geometric reasoning task."
+)
+
+parser.add_argument("--log_filename", type=str, default="train.log")
+
+parser.add_argument("--result_dir", type=str, default="results_default")
+
+parser.add_argument("--seed", type=int, default=0)
+
+parser.add_argument("--nb_epochs", type=int, default=25)
+
+parser.add_argument("--batch_size", type=int, default=100)
+
+parser.add_argument("--data_size", type=int, default=-1)
+
+parser.add_argument("--optim", type=str, default="adam")
+
+parser.add_argument("--learning_rate", type=float, default=1e-3)
+
+parser.add_argument(
+ "--learning_rate_schedule", type=str, default="10: 2e-4,20: 4e-5,30: 8e-6"
+)
+
+parser.add_argument("--dim_model", type=int, default=512)
+
+parser.add_argument("--dim_keys", type=int, default=64)
+
+parser.add_argument("--dim_hidden", type=int, default=2048)
+
+parser.add_argument("--nb_heads", type=int, default=8)
+
+parser.add_argument("--nb_blocks", type=int, default=12)
+
+parser.add_argument("--dropout", type=float, default=0.1)
+
+parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
+
+parser.add_argument("--no_checkpoint", action="store_true", default=False)
+
+parser.add_argument("--overwrite_results", action="store_true", default=False)
+
+parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
+
+##############################
+# picoclvr options
+
+parser.add_argument("--world_height", type=int, default=23)
+
+parser.add_argument("--world_width", type=int, default=31)
+
+parser.add_argument("--world_nb_walls", type=int, default=15)
+
+######################################################################
+
+args = parser.parse_args()
+
+assert args.prune_properties in {"none", "train+eval", "eval"}
+
+try:
+ os.mkdir(args.result_dir)
+except FileExistsError:
+ if not args.overwrite_results:
+ print(f"result directory {args.result_dir} already exists")
+ exit(1)
+
+log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
+
+if args.seed >= 0:
+ # torch.backends.cudnn.deterministic = True
+ # torch.backends.cudnn.benchmark = False
+ # torch.use_deterministic_algorithms(True)
+ torch.manual_seed(args.seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(args.seed)
+
+######################################################################
+
+
+def log_string(s):
+ t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
+
+ if log_file is not None:
+ log_file.write(t + s + "\n")
+ log_file.flush()
+
+ print(t + s)
+ sys.stdout.flush()
+
+
+for n in vars(args):
+ log_string(f"args.{n} {getattr(args, n)}")
+
+######################################################################
+
+
+def masked_inplace_autoregression(
+ model, batch_size, input, ar_mask, forbidden_tokens=None, device=torch.device("cpu")
+):
+
+ for input, ar_mask in zip(input.split(batch_size), ar_mask.split(batch_size)):
+ i = (ar_mask.sum(0) > 0).nonzero()
+ if i.min() > 0:
+ model(
+ mygpt.BracketedSequence(input, 0, i.min())
+ ) # Needed to initialize the model's cache
+ for s in range(i.min(), i.max() + 1):
+ output = model(mygpt.BracketedSequence(input, s, 1)).x
+ logits = output[:, s]
+ if forbidden_tokens is not None:
+ logits = logits.masked_fill(forbidden_tokens, float("-inf"))
+ if args.deterministic_synthesis:
+ t_next = logits.argmax(1)
+ else:
+ dist = torch.distributions.categorical.Categorical(logits=logits)
+ t_next = dist.sample()
+ input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
+
+
+######################################################################
+
+
+class Task:
+ def batches(self, split="train"):
+ pass
+
+ def vocabulary_size(self):
+ pass
+
+ def produce_results(self, n_epoch, model):
+ pass
+
+
+######################################################################
+
+import picoclvr
+
+
+class TaskPicoCLVR(Task):
+
+ # Make a tensor from a list of strings
+ def tensorize(self, descr):
+ token_descr = [s.strip().split(" ") for s in descr]
+ l = max([len(s) for s in token_descr])
+ token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
+ id_descr = [[self.token2id[u] for u in s] for s in token_descr]
+ return torch.tensor(id_descr, device=self.device)
+
+ # Make a list of strings from a tensor
+ def detensorize(self, x):
+ return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
+
+ # trim all the tensors in the tuple z to remove as much token from
+ # left and right in the first tensor. If z is a tuple, all its
+ # elements are trimed according to the triming for the first
+ def trim(self, z, token="<nul>"):
+ n = self.token2id[token]
+ if type(z) == tuple:
+ x = z[0]
+ i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
+ a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
+ return tuple([t[:, a:b] for t in z])
+ else:
+ i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
+ a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
+ return z[:, a:b]
+
+ ######################
+ # Not the cleanest part of the code
+
+ # Extract the last image of each sequence, from the last <img>
+ # included, and set to <nul> all the tokens from the beginning of
+ # that image to the end
+ def excise_last_image(self, input):
+ t_img, t_nul = self.token2id["<img>"], self.token2id["<nul>"]
+ nb_img_tokens = self.height * self.width + 1
+
+ input = input.clone()
+ t = (input == t_img).long()
+ tail_masks = (t.cumsum(dim=1) == t.sum(dim=1, keepdim=True)).long()
+ i = (t * tail_masks).nonzero(as_tuple=True)
+ j = (
+ i[0][:, None],
+ i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :],
+ )
+ images = self.trim(input[j])
+ input[j] = t_nul
+ loss_masks = 1 - tail_masks
+ input, loss_masks = self.trim((input, loss_masks))
+ return input, loss_masks, images
+
+ def add_true_image(self, input, images, loss_masks):
+ t_nul = self.token2id["<nul>"]
+ nb_img_tokens = self.height * self.width + 1
+ input = F.pad(input, (0, nb_img_tokens), value=t_nul)
+ loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0)
+ t = (input == t_nul).long()
+ i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True)
+ j = (
+ i[0][:, None],
+ i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :],
+ )
+ input[j] = images
+ loss_masks[j] = 1
+ input, loss_masks = self.trim((input, loss_masks))
+ return input, loss_masks
+
+ def add_generated_image(self, input, loss_masks, model):
+ t_img, t_nul = self.token2id["<img>"], self.token2id["<nul>"]
+ nb_img_tokens = self.height * self.width + 1
+
+ input = F.pad(input, (0, nb_img_tokens), value=t_nul)
+ loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0)
+ t = (input == t_nul).long()
+ i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True)
+ input[i] = t_img
+
+ j = (
+ i[0][:, None],
+ i[1][:, None]
+ + 1
+ + torch.arange(nb_img_tokens - 1, device=input.device)[None, :],
+ )
+ ar_masks = input.new_zeros(input.size(), dtype=torch.int64)
+ ar_masks[j] = 1
+ forbidden_tokens = (
+ torch.arange(self.vocabulary_size(), device=input.device) == t_nul
+ )
+ with torch.autograd.no_grad():
+ t = model.training
+ model.eval()
+ masked_inplace_autoregression(
+ model,
+ self.batch_size,
+ input,
+ ar_masks,
+ forbidden_tokens,
+ device=self.device,
+ )
+ model.train(t)
+
+ input, loss_masks = self.trim((input, loss_masks))
+
+ return input, loss_masks
+
+ ######################
+
+ def __init__(
+ self,
+ batch_size,
+ height,
+ width,
+ nb_colors=5,
+ device=torch.device("cpu"),
+ pruner_train=None,
+ pruner_eval=None,
+ ):
+ def generate_descr(nb, cache_suffix, pruner):
+ return picoclvr.generate(
+ nb,
+ height=self.height,
+ width=self.width,
+ nb_colors=nb_colors,
+ pruner=pruner,
+ )
+
+ self.height = height
+ self.width = width
+ self.batch_size = batch_size
+ self.device = device
+ nb = args.data_size if args.data_size > 0 else 250000
+ self.pruner_train = pruner_train
+ self.pruner_eval = pruner_eval
+
+ param = {
+ "nb": nb,
+ "height": height,
+ "width": width,
+ "nb_colors": nb_colors,
+ "batch_size": batch_size,
+ "rng_state": list(torch.get_rng_state()),
+ }
+
+ log_string(f"generating {nb} samples (can take some time)")
+ self.train_descr = generate_descr(
+ (nb * 4) // 5, "train", pruner=self.pruner_train
+ )
+ self.test_descr = generate_descr((nb * 1) // 5, "test", pruner=None)
+
+ # Build the tokenizer
+ tokens = {"<nul>", "<img>"}
+ for d in [self.train_descr, self.test_descr]:
+ for s in d:
+ for t in s.strip().split(" "):
+ tokens.add(t)
+ # make this set a sorted list to get the same tensors given
+ # the same descr
+ tokens = list(tokens)
+ tokens.sort()
+ self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
+ self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
+
+ # Tokenize the train and test sets
+ self.train_input = self.tensorize(self.train_descr)
+ self.test_input = self.tensorize(self.test_descr)
+
+ def batches(self, split="train"):
+ assert split in {"train", "test"}
+ input = self.train_input if split == "train" else self.test_input
+ for batch in tqdm.tqdm(
+ input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
+ ):
+ yield self.trim(batch)
+
+ def vocabulary_size(self):
+ return len(self.token2id)
+
+ def compute_missing_properties(self, n_epoch, model, pruner=None):
+
+ acc_nb_requested_properties = []
+ acc_nb_missing_properties = []
+ acc_nb_results = 0
+
+ for input in tqdm.tqdm(
+ self.test_input.split(self.batch_size),
+ dynamic_ncols=True,
+ desc=f"test-properties",
+ ):
+ tape, loss_masks, _ = self.excise_last_image(input)
+ tape, loss_masks = self.add_generated_image(tape, loss_masks, model)
+ result_descr = self.detensorize(tape)
+ np = picoclvr.nb_properties(
+ result_descr,
+ height=self.height,
+ width=self.width,
+ pruner=pruner,
+ )
+ nb_requested_properties, _, nb_missing_properties = zip(*np)
+ acc_nb_requested_properties += nb_requested_properties
+ acc_nb_missing_properties += nb_missing_properties
+ acc_nb_results += len(result_descr)
+
+ nb_requested_properties = sum(acc_nb_requested_properties)
+ nb_missing_properties = sum(acc_nb_missing_properties)
+
+ prefix = "" if pruner is None else "pruned_"
+ log_string(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
+ log_string(
+ f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
+ )
+ log_string(
+ f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
+ )
+
+ ######################################################################
+
+ def produce_results(self, n_epoch, model):
+
+ self.compute_missing_properties(n_epoch, model)
+
+ if self.pruner_eval is not None:
+ self.compute_missing_properties(n_epoch, model, self.pruner_eval)
+
+ nb_tokens_to_generate = self.height * self.width + 3
+ result_descr = []
+ nb_per_primer = 8
+ primer = []
+
+ for primer_descr in [
+ "red above green <sep> green top <sep> blue right of red",
+ "there is red <sep> there is yellow <sep> there is blue",
+ "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
+ "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
+ ]:
+ primer += [primer_descr] * nb_per_primer
+
+ tape = self.tensorize(primer)
+ loss_masks = 1 - (tape == self.token2id["<nul>"]).long()
+ tape, loss_masks = self.add_generated_image(tape, loss_masks, model)
+ result_descr = self.detensorize(tape)
+
+ np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
+
+ acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
+ acc_nb_results = len(result_descr)
+
+ nb_requested_properties = sum(acc_nb_requested_properties)
+ nb_missing_properties = sum(acc_nb_missing_properties)
+
+ prefix = "demo_"
+ log_string(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
+ log_string(
+ f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
+ )
+ log_string(
+ f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
+ )
+
+ img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)
+
+ if img.dim() == 5:
+ if img.size(1) == 1:
+ img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
+ else:
+ img = torch.cat(
+ [
+ torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
+ for x in img
+ ],
+ 0,
+ )
+
+ image_name = os.path.join(args.result_dir, f"result_{n_epoch:04d}.png")
+ torchvision.utils.save_image(
+ img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=1.0
+ )
+ log_string(f"wrote {image_name}")
+
+
+######################################################################
+
+log_string(f"device {device}")
+
+
+def pruner_horizontal_green(p):
+ return not ("green" in p and ("left" in p or "right" in p))
+
+
+task = TaskPicoCLVR(
+ batch_size=args.batch_size,
+ height=args.height,
+ width=args.width,
+ nb_colors=args.nb_colors,
+ device=device,
+ pruner_train=pruner_horizontal_green
+ if args.prune_properties in {"train+eval"}
+ else None,
+ pruner_eval=(lambda p: not pruner_horizontal_green(p))
+ if args.prune_properties in {"train+eval", "eval"}
+ else None,
+)
+
+vocabulary_size = task.vocabulary_size()
+
+log_string(f"vocabulary_size {vocabulary_size}")
+
+##############################
+
+model = mygpt.MyGPT(
+ vocabulary_size=vocabulary_size,
+ dim_model=args.dim_model,
+ dim_keys=args.dim_keys,
+ dim_hidden=args.dim_hidden,
+ nb_heads=args.nb_heads,
+ nb_blocks=args.nb_blocks,
+ causal=True,
+ dropout=args.dropout,
+)
+
+model.to(device)
+
+nb_parameters = sum(p.numel() for p in model.parameters())
+log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
+
+######################################################################
+
+nb_epochs_finished = 0
+
+if args.no_checkpoint:
+ log_string(f"not trying to load checkpoint.")
+
+else:
+ try:
+ checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
+ checkpoint = torch.load(checkpoint_name)
+ nb_epochs_finished = checkpoint["nb_epochs_finished"]
+ model.load_state_dict(checkpoint["model_state"])
+ torch.set_rng_state(checkpoint["rng_state"])
+ if torch.cuda.is_available():
+ torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])
+
+ log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.")
+
+ except FileNotFoundError:
+ log_string("starting from scratch.")
+
+ except:
+ log_string("error when loading the checkpoint.")
+ exit(1)
+
+######################################################################
+
+nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
+
+token_count = 0
+for input in task.batches(split="train"):
+ token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
+token_probas = token_count / token_count.sum()
+entropy = -torch.xlogy(token_probas, token_probas).sum()
+train_set_perplexity = math.exp(entropy)
+
+##############################
+
+if args.learning_rate_schedule == "cos":
+ learning_rate_schedule = {}
+ for n_epoch in range(args.nb_epochs):
+ u = n_epoch / args.nb_epochs * math.pi
+ learning_rate_schedule[n_epoch] = args.learning_rate * 0.5 * (1 + math.cos(u))
+else:
+ u = {
+ int(k): float(v)
+ for k, v in [
+ tuple(x.split(":")) for x in args.learning_rate_schedule.split(",")
+ ]
+ }
+
+ learning_rate_schedule = {}
+ learning_rate = args.learning_rate
+ for n_epoch in range(args.nb_epochs):
+ if n_epoch in u:
+ learning_rate = u[n_epoch]
+ learning_rate_schedule[n_epoch] = learning_rate
+
+log_string(f"learning_rate_schedule {learning_rate_schedule}")
+
+##############################
+
+nb_samples_seen = 0
+
+if nb_epochs_finished >= nb_epochs:
+ task.produce_results(nb_epochs_finished, model)
+
+for n_epoch in range(nb_epochs_finished, nb_epochs):
+
+ learning_rate = learning_rate_schedule[n_epoch]
+
+ log_string(f"learning_rate {learning_rate}")
+
+ if args.optim == "sgd":
+ optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
+ elif args.optim == "adam":
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
+ elif args.optim == "adamw":
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+ else:
+ raise ValueError(f"Unknown optimizer {args.optim}.")
+
+ model.train()
+
+ nb_train_samples, acc_train_loss = 0, 0.0
+
+ for input in task.batches(split="train"):
+ input = input.to(device)
+ output = model(mygpt.BracketedSequence(input)).x
+ loss = F.cross_entropy(output.transpose(1, 2), input)
+ acc_train_loss += loss.item() * input.size(0)
+ nb_train_samples += input.size(0)
+ nb_samples_seen += input.size(0)
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ with torch.autograd.no_grad():
+
+ model.eval()
+
+ nb_test_samples, acc_test_loss = 0, 0.0
+
+ for input in task.batches(split="test"):
+ input = input.to(device)
+
+ # input, loss_masks, true_images = task.excise_last_image(input)
+ # input, loss_masks = task.add_true_image(input, true_images, loss_masks)
+
+ output = model(mygpt.BracketedSequence(input)).x
+ loss = F.cross_entropy(output.transpose(1, 2), input)
+ acc_test_loss += loss.item() * input.size(0)
+ nb_test_samples += input.size(0)
+
+ train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+ test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
+
+ log_string(
+ f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
+ )
+
+ task.produce_results(n_epoch, model)
+
+ checkpoint = {
+ "nb_epochs_finished": n_epoch + 1,
+ "model_state": model.state_dict(),
+ "rng_state": torch.get_rng_state(),
+ }
+
+ if torch.cuda.is_available():
+ checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state()
+
+ checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
+ torch.save(checkpoint, checkpoint_name)
+ log_string(f"saved checkpoint {checkpoint_name}")
+
+######################################################################
--- /dev/null
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import torch, torchvision
+
+######################################################################
+
+v_empty, v_wall, v_start, v_goal, v_path = 0, 1, 2, 3, 4
+
+
+def create_maze(h=11, w=17, nb_walls=8):
+ a, k = 0, 0
+
+ while k < nb_walls:
+ while True:
+ if a == 0:
+ m = torch.zeros(h, w, dtype=torch.int64)
+ m[0, :] = 1
+ m[-1, :] = 1
+ m[:, 0] = 1
+ m[:, -1] = 1
+
+ r = torch.rand(4)
+
+ if r[0] <= 0.5:
+ i1, i2, j = (
+ int((r[1] * h).item()),
+ int((r[2] * h).item()),
+ int((r[3] * w).item()),
+ )
+ i1, i2, j = i1 - i1 % 2, i2 - i2 % 2, j - j % 2
+ i1, i2 = min(i1, i2), max(i1, i2)
+ if i2 - i1 > 1 and i2 - i1 <= h / 2 and m[i1 : i2 + 1, j].sum() <= 1:
+ m[i1 : i2 + 1, j] = 1
+ break
+ else:
+ i, j1, j2 = (
+ int((r[1] * h).item()),
+ int((r[2] * w).item()),
+ int((r[3] * w).item()),
+ )
+ i, j1, j2 = i - i % 2, j1 - j1 % 2, j2 - j2 % 2
+ j1, j2 = min(j1, j2), max(j1, j2)
+ if j2 - j1 > 1 and j2 - j1 <= w / 2 and m[i, j1 : j2 + 1].sum() <= 1:
+ m[i, j1 : j2 + 1] = 1
+ break
+ a += 1
+
+ if a > 10 * nb_walls:
+ a, k = 0, 0
+
+ k += 1
+
+ return m
+
+
+######################################################################
+
+
+def compute_distance(walls, i, j):
+ max_length = walls.numel()
+ dist = torch.full_like(walls, max_length)
+
+ dist[i, j] = 0
+ pred_dist = torch.empty_like(dist)
+
+ while True:
+ pred_dist.copy_(dist)
+ d = (
+ torch.cat(
+ (
+ dist[None, 1:-1, 0:-2],
+ dist[None, 2:, 1:-1],
+ dist[None, 1:-1, 2:],
+ dist[None, 0:-2, 1:-1],
+ ),
+ 0,
+ ).min(dim=0)[0]
+ + 1
+ )
+
+ dist[1:-1, 1:-1] = torch.min(dist[1:-1, 1:-1], d)
+ dist = walls * max_length + (1 - walls) * dist
+
+ if dist.equal(pred_dist):
+ return dist * (1 - walls)
+
+
+######################################################################
+
+
+def compute_policy(walls, i, j):
+ distance = compute_distance(walls, i, j)
+ distance = distance + walls.numel() * walls
+
+ value = distance.new_full((4,) + distance.size(), walls.numel())
+ value[0, :, 1:] = distance[:, :-1]
+ value[1, :, :-1] = distance[:, 1:]
+ value[2, 1:, :] = distance[:-1, :]
+ value[3, :-1, :] = distance[1:, :]
+
+ proba = (value.min(dim=0)[0][None] == value).float()
+ proba = proba / proba.sum(dim=0)[None]
+ proba = proba * (1 - walls) + walls.float() / 4
+
+ return proba
+
+
+######################################################################
+
+
+def mark_path(walls, i, j, goal_i, goal_j):
+ policy = compute_policy(walls, goal_i, goal_j)
+ action = torch.distributions.categorical.Categorical(
+ policy.permute(1, 2, 0)
+ ).sample()
+ walls[i, j] = 4
+ n, nmax = 0, walls.numel()
+ while i != goal_i or j != goal_j:
+ di, dj = [(0, -1), (0, 1), (-1, 0), (1, 0)][action[i, j]]
+ i, j = i + di, j + dj
+ assert walls[i, j] == 0
+ walls[i, j] = 4
+ n += 1
+ assert n < nmax
+
+
+def valid_paths(mazes, paths):
+ still_ok = (mazes - (paths * (paths < 4))).view(mazes.size(0), -1).abs().sum(1) == 0
+ reached = still_ok.new_zeros(still_ok.size())
+ current, pred_current = paths.clone(), paths.new_zeros(paths.size())
+ goal = (mazes == v_goal).long()
+ while not pred_current.equal(current):
+ # print(current)
+ # print(f'{still_ok=} {reached=}')
+ pred_current.copy_(current)
+ u = (current == v_start).long()
+ possible_next = (
+ u[:, 2:, 1:-1] + u[:, 0:-2, 1:-1] + u[:, 1:-1, 2:] + u[:, 1:-1, 0:-2] > 0
+ ).long()
+ u = u[:, 1:-1, 1:-1]
+ reached += ((goal[:, 1:-1, 1:-1] * possible_next).sum((1, 2)) == 1) * (
+ (current == v_path).sum((1, 2)) == 0
+ )
+ current[:, 1:-1, 1:-1] = (1 - u) * current[:, 1:-1, 1:-1] + (
+ v_start - v_path
+ ) * (possible_next * (current[:, 1:-1, 1:-1] == v_path))
+ still_ok *= (current == v_start).sum((1, 2)) <= 1
+
+ return still_ok * reached
+
+
+######################################################################
+
+
+def create_maze_data(nb, h=11, w=17, nb_walls=8, dist_min=-1):
+ mazes = torch.empty(nb, h, w, dtype=torch.int64)
+ paths = torch.empty(nb, h, w, dtype=torch.int64)
+
+ for n in range(nb):
+ maze = create_maze(h, w, nb_walls)
+ i = (1 - maze).nonzero()
+ while True:
+ start, goal = i[torch.randperm(i.size(0))[:2]]
+ if (start - goal).abs().sum() >= dist_min:
+ break
+
+ path = maze.clone()
+ mark_path(path, start[0], start[1], goal[0], goal[1])
+ maze[start[0], start[1]] = v_start
+ maze[goal[0], goal[1]] = v_goal
+ path[start[0], start[1]] = v_start
+ path[goal[0], goal[1]] = v_goal
+
+ mazes[n] = maze
+ paths[n] = path
+
+ return mazes, paths
+
+
+######################################################################
+
+
+def save_image(name, mazes, paths):
+ mazes, paths = mazes.cpu(), paths.cpu()
+
+ colors = torch.tensor(
+ [
+ [255, 255, 255], # empty
+ [0, 0, 0], # wall
+ [0, 255, 0], # start
+ [0, 0, 255], # goal
+ [255, 0, 0], # path
+ ]
+ )
+
+ mazes = colors[mazes.reshape(-1)].reshape(mazes.size() + (-1,)).permute(0, 3, 1, 2)
+ paths = colors[paths.reshape(-1)].reshape(paths.size() + (-1,)).permute(0, 3, 1, 2)
+
+ img = torch.cat((mazes.unsqueeze(1), paths.unsqueeze(1)), 1)
+ img = img.reshape((-1,) + img.size()[2:]).float() / 255.0
+
+ torchvision.utils.save_image(img, name, padding=1, pad_value=0.5, nrow=8)
+
+
+######################################################################
+
+if __name__ == "__main__":
+
+ mazes, paths = create_maze_data(32, dist_min=10)
+ save_image("test.png", mazes, paths)
+ print(valid_paths(mazes, paths))
+
+######################################################################
--- /dev/null
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import math
+
+import torch
+
+from torch import nn
+from torch.nn import functional as F
+
+######################################################################
+
+
+class WithResidual(nn.Module):
+ def __init__(self, *f):
+ super().__init__()
+ self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+
+ def forward(self, bs):
+ bs.x = bs.x + self.f(bs).x
+ return bs
+
+
+######################################################################
+
+# A BracketedSequence is a BxTx... tensor with a first and a nb time
+# steps to compute.
+
+# Modules able to process it expect that they will have to process a
+# first bracket starting at t=0, followed by a succession of brackets
+# that move forward in time, do not overlap, and cover the axis T with
+# no holes.
+#
+# Although it is more general, for a classical prompt-conditioned
+# auto-regressive process it will be a first bracket starting at 0 and
+# of arbitrary length for the "prompt", followed by brackets of length
+# 1 for the successive tokens.
+#
+# Modules able to process brackets may implement a cache that is
+# resetted when the input bracket starts at t=0
+
+
+class BracketedSequence:
+ def __init__(self, x, first=None, nb=None):
+ self.x = x
+ self.first = 0 if first is None else first
+ self.nb = x.size(1) if nb is None else nb
+
+ def slice(self):
+ return self.x[:, self.first : self.first + self.nb]
+
+
+######################################################################
+
+
+class CacheWrapper(nn.Module):
+ def __init__(self, *f):
+ super().__init__()
+ self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+
+ def forward(self, bs):
+ if bs.first == 0:
+ y = self.f(bs.slice())
+ self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:]))
+ self.cache_y[:, bs.first : bs.first + bs.nb] = y
+ else:
+ self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
+
+ bs.x = self.cache_y
+
+ return bs
+
+
+##############################
+
+
+class AddPositionalEncoding(nn.Module):
+ def __init__(self, len_max):
+ super().__init__()
+ self.len_max = len_max
+
+ # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
+
+ def forward(self, bs):
+ if bs.first == 0:
+ t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
+ :, None
+ ]
+ j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
+ None, :
+ ]
+ k = j % 2
+ self.pe = torch.sin(
+ t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
+ )
+ self.cache_y = bs.x.new(bs.x.size())
+
+ self.cache_y[:, bs.first : bs.first + bs.nb] = (
+ bs.slice() + self.pe[bs.first : bs.first + bs.nb]
+ )
+
+ bs.x = self.cache_y
+
+ return bs
+
+
+##############################
+
+
+class QKVAttention(nn.Module):
+ def __init__(
+ self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0
+ ):
+ super().__init__()
+
+ def randw(*d):
+ return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+
+ self.causal = causal
+ self.attention_dropout = attention_dropout
+
+ self.w_q = randw(nb_heads, dim_qk, dim_in)
+ self.w_k = randw(nb_heads, dim_qk, dim_in)
+ self.w_v = randw(nb_heads, dim_v, dim_in)
+ self.w_o = randw(dim_v * nb_heads, dim_in)
+
+ def forward(self, bs_q):
+ x_q = bs_q.x
+
+ if bs_q.first == 0:
+ self.cache_k = x_q.new_zeros(
+ x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
+ )
+ self.cache_v = x_q.new_zeros(
+ x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
+ )
+ self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
+
+ q = torch.einsum(
+ "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_q
+ )
+ self.cache_k[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum(
+ "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_k
+ )
+ self.cache_v[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum(
+ "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_v
+ )
+
+ a = torch.einsum(
+ "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb]
+ ) / math.sqrt(self.w_q.size(1))
+
+ if self.causal:
+ if bs_q.first == 0:
+ self.cache_attzero = (
+ torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
+ < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
+ )
+ a = a.masked_fill(
+ self.cache_attzero[
+ :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb
+ ],
+ float("-inf"),
+ )
+
+ a = a.softmax(dim=3)
+ a = F.dropout(a, self.attention_dropout, self.training)
+
+ y = torch.einsum(
+ "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs_q.first + bs_q.nb]
+ ).flatten(2)
+
+ self.cache_y[:, bs_q.first : bs_q.first + bs_q.nb] = y @ self.w_o
+
+ bs_q.x = self.cache_y
+
+ return bs_q
+
+
+##############################
+
+
+class MyGPT(nn.Module):
+ def __init__(
+ self,
+ vocabulary_size,
+ dim_model,
+ dim_keys,
+ dim_hidden,
+ nb_heads,
+ nb_blocks,
+ causal=False,
+ dropout=0.0,
+ len_max=1e5,
+ ):
+
+ super().__init__()
+
+ assert dim_model % nb_heads == 0
+
+ self.embedding = nn.Sequential(
+ CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
+ AddPositionalEncoding(len_max),
+ )
+
+ trunk_blocks = []
+
+ for b in range(nb_blocks):
+ trunk_blocks += [
+ WithResidual(
+ CacheWrapper(nn.LayerNorm((dim_model,))),
+ QKVAttention(
+ dim_in=dim_model,
+ dim_qk=dim_keys,
+ dim_v=dim_model // nb_heads,
+ nb_heads=nb_heads,
+ causal=causal,
+ attention_dropout=dropout,
+ ),
+ ),
+ WithResidual(
+ CacheWrapper(
+ nn.LayerNorm((dim_model,)),
+ nn.Linear(in_features=dim_model, out_features=dim_hidden),
+ nn.ReLU(),
+ nn.Linear(in_features=dim_hidden, out_features=dim_model),
+ nn.Dropout(dropout),
+ ),
+ ),
+ ]
+
+ self.trunk = nn.Sequential(*trunk_blocks)
+
+ self.readout = CacheWrapper(
+ nn.Linear(in_features=dim_model, out_features=vocabulary_size)
+ )
+
+ with torch.no_grad():
+ for m in self.modules():
+ if isinstance(m, nn.Embedding):
+ m.weight.normal_(mean=0, std=2e-2)
+ elif isinstance(m, nn.LayerNorm):
+ m.bias.zero_()
+ m.weight.fill_(1.0)
+
+ def forward(self, bs):
+ bs.x = F.pad(bs.x, (1, -1))
+ bs = self.embedding(bs)
+ bs = self.trunk(bs)
+ bs = self.readout(bs)
+ return bs
+
+
+######################################################################
+
+if __name__ == "__main__":
+
+ print("Basic check.")
+
+ vocabulary_size = 10
+ x = torch.randint(vocabulary_size, (9, 7))
+
+ model = MyGPT(
+ vocabulary_size=vocabulary_size,
+ dim_model=18,
+ dim_keys=50,
+ dim_hidden=100,
+ nb_heads=2,
+ nb_blocks=1,
+ dropout=0.1,
+ )
+
+ model.eval()
+
+ y1 = model(BracketedSequence(x)).x
+
+ y2 = torch.randn_like(y1)
+ for s in range(x.size(1)):
+ z = model(BracketedSequence(x, s, 1))
+ y2[:, s] = z.x[:, s]
+
+ # print(y1.max(dim = 2).values)
+ # print(y2.max(dim = 2).values)
+ print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
+
+######################################################################