Initial commit.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 17 Dec 2022 11:41:39 +0000 (12:41 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 17 Dec 2022 11:41:39 +0000 (12:41 +0100)
main.py [new file with mode: 0755]
mygpt.py [new file with mode: 0755]
picoclvr.py [new file with mode: 0755]
tensorstack.py [new file with mode: 0755]

diff --git a/main.py b/main.py
new file mode 100755 (executable)
index 0000000..6d9f69d
--- /dev/null
+++ b/main.py
@@ -0,0 +1,630 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import math, sys, argparse, time, tqdm, itertools, os
+
+import torch, torchvision
+from torch import nn
+from torch.nn import functional as F
+
+import mygpt, tensorstack
+
+######################################################################
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+######################################################################
+
+parser = argparse.ArgumentParser(
+    description="An implementation of GPT with cache to solve a toy geometric reasonning task."
+)
+
+parser.add_argument("--log_filename", type=str, default="train.log")
+
+parser.add_argument("--result_dir", type=str, default="results_default")
+
+parser.add_argument("--seed", type=int, default=0)
+
+parser.add_argument("--nb_epochs", type=int, default=25)
+
+parser.add_argument("--batch_size", type=int, default=100)
+
+parser.add_argument("--data_size", type=int, default=-1)
+
+parser.add_argument("--optim", type=str, default="adam")
+
+parser.add_argument("--learning_rate", type=float, default=1e-3)
+
+parser.add_argument(
+    "--learning_rate_schedule", type=str, default="10: 2e-4,20: 4e-5,30: 8e-6"
+)
+
+parser.add_argument("--dim_model", type=int, default=512)
+
+parser.add_argument("--dim_keys", type=int, default=64)
+
+parser.add_argument("--dim_hidden", type=int, default=2048)
+
+parser.add_argument("--nb_heads", type=int, default=8)
+
+parser.add_argument("--nb_blocks", type=int, default=12)
+
+parser.add_argument("--dropout", type=float, default=0.1)
+
+parser.add_argument("--nb_oneshot_blocks", type=int, default=-1)
+
+parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
+
+parser.add_argument("--no_checkpoint", action="store_true", default=False)
+
+parser.add_argument("--overwrite_results", action="store_true", default=False)
+
+parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
+
+##############################
+# picoclvr options
+
+parser.add_argument("--nb_colors", type=int, default=5)
+
+parser.add_argument("--height", type=int, default=12)
+
+parser.add_argument("--width", type=int, default=16)
+
+parser.add_argument("--prune_properties", type=str, default="none")
+
+######################################################################
+
+args = parser.parse_args()
+
+assert args.prune_properties in {"none", "train+eval", "eval"}
+
+try:
+    os.mkdir(args.result_dir)
+except FileExistsError:
+    if not args.overwrite_results:
+        print(f"result directory {args.result_dir} already exists")
+        exit(1)
+
+log_file = open(os.path.join(args.result_dir, args.log_filename), "w")
+
+if args.seed >= 0:
+    # torch.backends.cudnn.deterministic = True
+    # torch.backends.cudnn.benchmark = False
+    # torch.use_deterministic_algorithms(True)
+    torch.manual_seed(args.seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(args.seed)
+
+######################################################################
+
+
+def log_string(s):
+    t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
+
+    if log_file is not None:
+        log_file.write(t + s + "\n")
+        log_file.flush()
+
+    print(t + s)
+    sys.stdout.flush()
+
+
+for n in vars(args):
+    log_string(f"args.{n} {getattr(args, n)}")
+
+######################################################################
+
+
+def masked_inplace_autoregression(
+    model, batch_size, input, ar_mask, forbidden_tokens=None, device=torch.device("cpu")
+):
+
+    for input, ar_mask in zip(input.split(batch_size), ar_mask.split(batch_size)):
+        i = (ar_mask.sum(0) > 0).nonzero()
+        if i.min() > 0:
+            model(
+                mygpt.BracketedSequence(input, 0, i.min())
+            )  # Needed to initialize the model's cache
+        for s in range(i.min(), i.max() + 1):
+            output = model(mygpt.BracketedSequence(input, s, 1)).x
+            logits = output[:, s]
+            if forbidden_tokens is not None:
+                logits = logits.masked_fill(forbidden_tokens, float("-inf"))
+            if args.deterministic_synthesis:
+                t_next = logits.argmax(1)
+            else:
+                dist = torch.distributions.categorical.Categorical(logits=logits)
+                t_next = dist.sample()
+            input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
+
+
+######################################################################
+
+
+class Task:
+    def batches(self, split="train"):
+        pass
+
+    def vocabulary_size(self):
+        pass
+
+    def produce_results(self, n_epoch, model):
+        pass
+
+
+######################################################################
+
+import picoclvr
+
+
+class TaskPicoCLVR(Task):
+
+    # Make a tensor from a list of strings
+    def tensorize(self, descr):
+        token_descr = [s.strip().split(" ") for s in descr]
+        l = max([len(s) for s in token_descr])
+        token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
+        id_descr = [[self.token2id[u] for u in s] for s in token_descr]
+        return torch.tensor(id_descr, device=self.device)
+
+    # Make a list of strings from a tensor
+    def detensorize(self, x):
+        return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
+
+    # trim all the tensors in the tuple z to remove as much token from
+    # left and right in the first tensor. If z is a tuple, all its
+    # elements are trimed according to the triming for the first
+    def trim(self, z, token="<nul>"):
+        n = self.token2id[token]
+        if type(z) == tuple:
+            x = z[0]
+            i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
+            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
+            return tuple([t[:, a:b] for t in z])
+        else:
+            i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
+            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
+            return z[:, a:b]
+
+    ######################
+    # Not the cleanest part of the code
+
+    # Extract the last image of each sequence, from the last <img>
+    # included, and set to <nul> all the tokens from the beginning of
+    # that image to the end
+    def excise_last_image(self, input):
+        t_img, t_nul = self.token2id["<img>"], self.token2id["<nul>"]
+        nb_img_tokens = self.height * self.width + 1
+
+        input = input.clone()
+        t = (input == t_img).long()
+        tail_masks = (t.cumsum(dim=1) == t.sum(dim=1, keepdim=True)).long()
+        i = (t * tail_masks).nonzero(as_tuple=True)
+        j = (
+            i[0][:, None],
+            i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :],
+        )
+        images = self.trim(input[j])
+        input[j] = t_nul
+        loss_masks = 1 - tail_masks
+        input, loss_masks = self.trim((input, loss_masks))
+        return input, loss_masks, images
+
+    def add_true_image(self, input, images, loss_masks):
+        t_nul = self.token2id["<nul>"]
+        nb_img_tokens = self.height * self.width + 1
+        input = F.pad(input, (0, nb_img_tokens), value=t_nul)
+        loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0)
+        t = (input == t_nul).long()
+        i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True)
+        j = (
+            i[0][:, None],
+            i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :],
+        )
+        input[j] = images
+        loss_masks[j] = 1
+        input, loss_masks = self.trim((input, loss_masks))
+        return input, loss_masks
+
+    def add_generated_image(self, input, loss_masks, model):
+        t_img, t_nul = self.token2id["<img>"], self.token2id["<nul>"]
+        nb_img_tokens = self.height * self.width + 1
+
+        input = F.pad(input, (0, nb_img_tokens), value=t_nul)
+        loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0)
+        t = (input == t_nul).long()
+        i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True)
+        input[i] = t_img
+
+        j = (
+            i[0][:, None],
+            i[1][:, None]
+            + 1
+            + torch.arange(nb_img_tokens - 1, device=input.device)[None, :],
+        )
+        ar_masks = input.new_zeros(input.size(), dtype=torch.int64)
+        ar_masks[j] = 1
+        forbidden_tokens = (
+            torch.arange(self.vocabulary_size(), device=input.device) == t_nul
+        )
+        with torch.autograd.no_grad():
+            t = model.training
+            model.eval()
+            masked_inplace_autoregression(
+                model,
+                self.batch_size,
+                input,
+                ar_masks,
+                forbidden_tokens,
+                device=self.device,
+            )
+            model.train(t)
+
+        input, loss_masks = self.trim((input, loss_masks))
+
+        return input, loss_masks
+
+    ######################
+
+    def __init__(
+        self,
+        batch_size,
+        height,
+        width,
+        nb_colors=5,
+        device=torch.device("cpu"),
+        pruner_train=None,
+        pruner_eval=None,
+    ):
+        def generate_descr(nb, cache_suffix, pruner):
+            return picoclvr.generate(
+                nb,
+                height=self.height,
+                width=self.width,
+                nb_colors=nb_colors,
+                pruner=pruner,
+            )
+
+        self.height = height
+        self.width = width
+        self.batch_size = batch_size
+        self.device = device
+        nb = args.data_size if args.data_size > 0 else 250000
+        self.pruner_train = pruner_train
+        self.pruner_eval = pruner_eval
+
+        param = {
+            "nb": nb,
+            "height": height,
+            "width": width,
+            "nb_colors": nb_colors,
+            "batch_size": batch_size,
+            "rng_state": list(torch.get_rng_state()),
+        }
+
+        log_string(f"generating {nb} samples (can take some time)")
+        self.train_descr = generate_descr(
+            (nb * 4) // 5, "train", pruner=self.pruner_train
+        )
+        self.test_descr = generate_descr((nb * 1) // 5, "test", pruner=None)
+
+        # Build the tokenizer
+        tokens = {"<nul>", "<img>"}
+        for d in [self.train_descr, self.test_descr]:
+            for s in d:
+                for t in s.strip().split(" "):
+                    tokens.add(t)
+        # make this set a sorted list to get the same tensors given
+        # the same descr
+        tokens = list(tokens)
+        tokens.sort()
+        self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
+        self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
+
+        # Tokenize the train and test sets
+        self.train_input = self.tensorize(self.train_descr)
+        self.test_input = self.tensorize(self.test_descr)
+
+    def batches(self, split="train"):
+        assert split in {"train", "test"}
+        input = self.train_input if split == "train" else self.test_input
+        for batch in tqdm.tqdm(
+            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
+        ):
+            yield self.trim(batch)
+
+    def vocabulary_size(self):
+        return len(self.token2id)
+
+    def compute_missing_properties(self, n_epoch, model, pruner=None):
+
+        acc_nb_requested_properties = []
+        acc_nb_missing_properties = []
+        acc_nb_results = 0
+
+        for input in tqdm.tqdm(
+            self.test_input.split(self.batch_size),
+            dynamic_ncols=True,
+            desc=f"test-properties",
+        ):
+            tape, loss_masks, _ = self.excise_last_image(input)
+            tape, loss_masks = self.add_generated_image(tape, loss_masks, model)
+            result_descr = self.detensorize(tape)
+            np = picoclvr.nb_properties(
+                result_descr,
+                height=self.height,
+                width=self.width,
+                pruner=pruner,
+            )
+            nb_requested_properties, _, nb_missing_properties = zip(*np)
+            acc_nb_requested_properties += nb_requested_properties
+            acc_nb_missing_properties += nb_missing_properties
+            acc_nb_results += len(result_descr)
+
+        nb_requested_properties = sum(acc_nb_requested_properties)
+        nb_missing_properties = sum(acc_nb_missing_properties)
+
+        prefix = "" if pruner is None else "pruned_"
+        log_string(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
+        log_string(
+            f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
+        )
+        log_string(
+            f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
+        )
+
+    ######################################################################
+
+    def produce_results(self, n_epoch, model):
+
+        self.compute_missing_properties(n_epoch, model)
+
+        if self.pruner_eval is not None:
+            self.compute_missing_properties(n_epoch, model, self.pruner_eval)
+
+        nb_tokens_to_generate = self.height * self.width + 3
+        result_descr = []
+        nb_per_primer = 8
+        primer = []
+
+        for primer_descr in [
+            "red above green <sep> green top <sep> blue right of red",
+            "there is red <sep> there is yellow <sep> there is blue",
+            "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
+            "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
+        ]:
+            primer += [primer_descr] * nb_per_primer
+
+        tape = self.tensorize(primer)
+        loss_masks = 1 - (tape == self.token2id["<nul>"]).long()
+        tape, loss_masks = self.add_generated_image(tape, loss_masks, model)
+        result_descr = self.detensorize(tape)
+
+        np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
+
+        acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
+        acc_nb_results = len(result_descr)
+
+        nb_requested_properties = sum(acc_nb_requested_properties)
+        nb_missing_properties = sum(acc_nb_missing_properties)
+
+        prefix = "demo_"
+        log_string(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
+        log_string(
+            f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
+        )
+        log_string(
+            f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
+        )
+
+        img = picoclvr.descr2img(
+            result_descr, [0], height=self.height, width=self.width
+        )
+
+        if img.dim() == 5:
+            if img.size(1) == 1:
+                img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
+            else:
+                img = torch.cat(
+                    [
+                        torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
+                        for x in img
+                    ],
+                    0,
+                )
+
+        image_name = os.path.join(args.result_dir, f"result_{n_epoch:04d}.png")
+        torchvision.utils.save_image(
+            img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=1.0
+        )
+        log_string(f"wrote {image_name}")
+
+
+######################################################################
+
+log_string(f"device {device}")
+
+
+def pruner_horizontal_green(p):
+    return not ("green" in p and ("left" in p or "right" in p))
+
+
+task = TaskPicoCLVR(
+    batch_size=args.batch_size,
+    height=args.height,
+    width=args.width,
+    nb_colors=args.nb_colors,
+    device=device,
+    pruner_train=pruner_horizontal_green
+    if args.prune_properties in {"train+eval"}
+    else None,
+    pruner_eval=(lambda p: not pruner_horizontal_green(p))
+    if args.prune_properties in {"train+eval", "eval"}
+    else None,
+)
+
+vocabulary_size = task.vocabulary_size()
+
+log_string(f"vocabulary_size {vocabulary_size}")
+
+##############################
+
+model = mygpt.MyGPT(
+    vocabulary_size=vocabulary_size,
+    dim_model=args.dim_model,
+    dim_keys=args.dim_keys,
+    dim_hidden=args.dim_hidden,
+    nb_heads=args.nb_heads,
+    nb_blocks=args.nb_blocks,
+    causal=True,
+    dropout=args.dropout,
+)
+
+model.to(device)
+
+nb_parameters = sum(p.numel() for p in model.parameters())
+log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
+
+######################################################################
+
+nb_epochs_finished = 0
+
+if args.no_checkpoint:
+    log_string(f"not trying to load checkpoint.")
+
+else:
+    try:
+        checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
+        checkpoint = torch.load(checkpoint_name)
+        nb_epochs_finished = checkpoint["nb_epochs_finished"]
+        model.load_state_dict(checkpoint["model_state"])
+        torch.set_rng_state(checkpoint["rng_state"])
+        if torch.cuda.is_available():
+            torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])
+
+        log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.")
+
+    except FileNotFoundError:
+        log_string("starting from scratch.")
+
+    except:
+        log_string("error when loading the checkpoint.")
+        exit(1)
+
+######################################################################
+
+nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
+
+token_count = 0
+for input in task.batches(split="train"):
+    token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
+token_probas = token_count / token_count.sum()
+entropy = -torch.xlogy(token_probas, token_probas).sum()
+train_set_perplexity = math.exp(entropy)
+
+##############################
+
+if args.learning_rate_schedule == "cos":
+    learning_rate_schedule = {}
+    for n_epoch in range(args.nb_epochs):
+        u = n_epoch / args.nb_epochs * math.pi
+        learning_rate_schedule[n_epoch] = args.learning_rate * 0.5 * (1 + math.cos(u))
+else:
+    u = {
+        int(k): float(v)
+        for k, v in [
+            tuple(x.split(":")) for x in args.learning_rate_schedule.split(",")
+        ]
+    }
+
+    learning_rate_schedule = {}
+    learning_rate = args.learning_rate
+    for n_epoch in range(args.nb_epochs):
+        if n_epoch in u:
+            learning_rate = u[n_epoch]
+        learning_rate_schedule[n_epoch] = learning_rate
+
+log_string(f"learning_rate_schedule {learning_rate_schedule}")
+
+##############################
+
+nb_samples_seen = 0
+
+if nb_epochs_finished >= nb_epochs:
+    task.produce_results(nb_epochs_finished, model)
+
+for n_epoch in range(nb_epochs_finished, nb_epochs):
+
+    learning_rate = learning_rate_schedule[n_epoch]
+
+    log_string(f"learning_rate {learning_rate}")
+
+    if args.optim == "sgd":
+        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
+    elif args.optim == "adam":
+        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
+    elif args.optim == "adamw":
+        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+    else:
+        raise ValueError(f"Unknown optimizer {args.optim}.")
+
+    model.train()
+
+    nb_train_samples, acc_train_loss = 0, 0.0
+
+    for input in task.batches(split="train"):
+        input = input.to(device)
+        output = model(mygpt.BracketedSequence(input)).x
+        loss = F.cross_entropy(output.transpose(1, 2), input)
+        acc_train_loss += loss.item() * input.size(0)
+        nb_train_samples += input.size(0)
+        nb_samples_seen += input.size(0)
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+    with torch.autograd.no_grad():
+
+        model.eval()
+
+        nb_test_samples, acc_test_loss = 0, 0.0
+
+        for input in task.batches(split="test"):
+            input = input.to(device)
+
+            # input, loss_masks, true_images = task.excise_last_image(input)
+            # input, loss_masks = task.add_true_image(input, true_images, loss_masks)
+
+            output = model(mygpt.BracketedSequence(input)).x
+            loss = F.cross_entropy(output.transpose(1, 2), input)
+            acc_test_loss += loss.item() * input.size(0)
+            nb_test_samples += input.size(0)
+
+        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
+
+        log_string(
+            f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
+        )
+
+        task.produce_results(n_epoch, model)
+
+    checkpoint = {
+        "nb_epochs_finished": n_epoch + 1,
+        "model_state": model.state_dict(),
+        "rng_state": torch.get_rng_state(),
+    }
+
+    if torch.cuda.is_available():
+        checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state()
+
+    checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
+    torch.save(checkpoint, checkpoint_name)
+    log_string(f"saved checkpoint {checkpoint_name}")
+
+######################################################################
diff --git a/mygpt.py b/mygpt.py
new file mode 100755 (executable)
index 0000000..0ed7eb0
--- /dev/null
+++ b/mygpt.py
@@ -0,0 +1,292 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import math
+
+import torch
+
+from torch import nn
+from torch.nn import functional as F
+
+######################################################################
+
+
+class WithResidual(nn.Module):
+    def __init__(self, *f):
+        super().__init__()
+        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+
+    def forward(self, bs):
+        bs.x = bs.x + self.f(bs).x
+        return bs
+
+
+######################################################################
+
+# A BracketedSequence is a BxTx... tensor with a first and a nb time
+# steps to compute.
+
+# Modules able to process it expect that they will have to process a
+# first bracket starting at t=0, followed by a succession of brackets
+# that move forward in time, do not overlap, and cover the axis T with
+# no holes.
+#
+# Although it is more general, for a classical prompt-conditioned
+# auto-regressive process it will be a first bracket starting at 0 and
+# of arbitrary length for the "prompt", followed by brackets of length
+# 1 for the successive tokens.
+#
+# Modules able to process brackets may implement a cache that is
+# resetted when the input bracket starts at t=0
+
+
+class BracketedSequence:
+    def __init__(self, x, first=None, nb=None):
+        self.x = x
+        self.first = 0 if first is None else first
+        self.nb = x.size(1) if nb is None else nb
+
+    def slice(self):
+        return self.x[:, self.first : self.first + self.nb]
+
+
+######################################################################
+
+
+class CacheWrapper(nn.Module):
+    def __init__(self, *f):
+        super().__init__()
+        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+
+    def forward(self, bs):
+        if bs.first == 0:
+            y = self.f(bs.slice())
+            self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:]))
+            self.cache_y[:, bs.first : bs.first + bs.nb] = y
+        else:
+            self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
+
+        bs.x = self.cache_y
+
+        return bs
+
+
+##############################
+
+
+class AddPositionalEncoding(nn.Module):
+    def __init__(self, len_max):
+        super().__init__()
+        self.len_max = len_max
+
+    # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
+
+    def forward(self, bs):
+        if bs.first == 0:
+            t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
+                :, None
+            ]
+            j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
+                None, :
+            ]
+            k = j % 2
+            self.pe = torch.sin(
+                t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
+            )
+            self.cache_y = bs.x.new(bs.x.size())
+
+        self.cache_y[:, bs.first : bs.first + bs.nb] = (
+            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
+        )
+
+        bs.x = self.cache_y
+
+        return bs
+
+
+##############################
+
+
+class QKVAttention(nn.Module):
+    def __init__(
+        self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0
+    ):
+        super().__init__()
+
+        def randw(*d):
+            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+
+        self.causal = causal
+        self.attention_dropout = attention_dropout
+
+        self.w_q = randw(nb_heads, dim_qk, dim_in)
+        self.w_k = randw(nb_heads, dim_qk, dim_in)
+        self.w_v = randw(nb_heads, dim_v, dim_in)
+        self.w_o = randw(dim_v * nb_heads, dim_in)
+
+    def forward(self, bs_q, x_kv=None):
+        x_q = bs_q.x
+        if x_kv is None:
+            x_kv = x_q
+
+        if bs_q.first == 0:
+            self.cache_k = x_q.new_zeros(
+                x_q.size(0), self.w_k.size(0), x_kv.size(1), self.w_k.size(1)
+            )
+            self.cache_v = x_q.new_zeros(
+                x_q.size(0), self.w_v.size(0), x_kv.size(1), self.w_v.size(1)
+            )
+            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
+
+        q = torch.einsum(
+            "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_q
+        )
+        self.cache_k[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum(
+            "ntc,hdc->nhtd", x_kv[:, bs_q.first : bs_q.first + bs_q.nb], self.w_k
+        )
+        self.cache_v[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum(
+            "ntc,hdc->nhtd", x_kv[:, bs_q.first : bs_q.first + bs_q.nb], self.w_v
+        )
+
+        a = torch.einsum(
+            "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb]
+        ) / math.sqrt(self.w_q.size(1))
+
+        if self.causal:
+            if bs_q.first == 0:
+                self.cache_attzero = (
+                    torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
+                    < torch.arange(x_kv.size(1), device=q.device)[None, None, None, :]
+                )
+            a = a.masked_fill(
+                self.cache_attzero[
+                    :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb
+                ],
+                float("-inf"),
+            )
+
+        a = a.softmax(dim=3)
+        a = F.dropout(a, self.attention_dropout, self.training)
+
+        y = torch.einsum(
+            "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs_q.first + bs_q.nb]
+        ).flatten(2)
+
+        self.cache_y[:, bs_q.first : bs_q.first + bs_q.nb] = y @ self.w_o
+
+        bs_q.x = self.cache_y
+
+        return bs_q
+
+
+##############################
+
+
+class MyGPT(nn.Module):
+    def __init__(
+        self,
+        vocabulary_size,
+        dim_model,
+        dim_keys,
+        dim_hidden,
+        nb_heads,
+        nb_blocks,
+        causal=False,
+        dropout=0.0,
+        len_max=1e5,
+    ):
+
+        super().__init__()
+
+        assert dim_model % nb_heads == 0
+
+        self.embedding = nn.Sequential(
+            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
+            AddPositionalEncoding(len_max),
+        )
+
+        trunk_blocks = []
+
+        for b in range(nb_blocks):
+            trunk_blocks += [
+                WithResidual(
+                    CacheWrapper(nn.LayerNorm((dim_model,))),
+                    QKVAttention(
+                        dim_in=dim_model,
+                        dim_qk=dim_keys,
+                        dim_v=dim_model // nb_heads,
+                        nb_heads=nb_heads,
+                        causal=causal,
+                        attention_dropout=dropout,
+                    ),
+                ),
+                WithResidual(
+                    CacheWrapper(
+                        nn.LayerNorm((dim_model,)),
+                        nn.Linear(in_features=dim_model, out_features=dim_hidden),
+                        nn.ReLU(),
+                        nn.Linear(in_features=dim_hidden, out_features=dim_model),
+                        nn.Dropout(dropout),
+                    ),
+                ),
+            ]
+
+        self.trunk = nn.Sequential(*trunk_blocks)
+
+        self.readout = CacheWrapper(
+            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
+        )
+
+        with torch.no_grad():
+            for m in self.modules():
+                if isinstance(m, nn.Embedding):
+                    m.weight.normal_(mean=0, std=2e-2)
+                elif isinstance(m, nn.LayerNorm):
+                    m.bias.zero_()
+                    m.weight.fill_(1.0)
+
+    def forward(self, bs):
+        bs.x = F.pad(bs.x, (1, -1))
+        bs = self.embedding(bs)
+        bs = self.trunk(bs)
+        bs = self.readout(bs)
+        return bs
+
+
+######################################################################
+
+if __name__ == "__main__":
+
+    print("Basic check.")
+
+    vocabulary_size = 10
+    x = torch.randint(vocabulary_size, (9, 7))
+
+    model = MyGPT(
+        vocabulary_size=vocabulary_size,
+        dim_model=18,
+        dim_keys=50,
+        dim_hidden=100,
+        nb_heads=2,
+        nb_blocks=1,
+        dropout=0.1,
+    )
+
+    model.eval()
+
+    y1 = model(BracketedSequence(x)).x
+
+    y2 = torch.randn_like(y1)
+    for s in range(x.size(1)):
+        z = model(BracketedSequence(x, s, 1))
+        y2[:, s] = z.x[:, s]
+
+    # print(y1.max(dim = 2).values)
+    # print(y2.max(dim = 2).values)
+    print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
+
+######################################################################
diff --git a/picoclvr.py b/picoclvr.py
new file mode 100755 (executable)
index 0000000..94c0f88
--- /dev/null
@@ -0,0 +1,511 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import torch, torchvision
+import torch.nn.functional as F
+
+colors = [
+    [255, 255, 255],
+    [255, 0, 0],
+    [0, 128, 0],
+    [0, 0, 255],
+    [255, 255, 0],
+    [0, 0, 0],
+    [128, 0, 0],
+    [139, 0, 0],
+    [165, 42, 42],
+    [178, 34, 34],
+    [220, 20, 60],
+    [255, 99, 71],
+    [255, 127, 80],
+    [205, 92, 92],
+    [240, 128, 128],
+    [233, 150, 122],
+    [250, 128, 114],
+    [255, 160, 122],
+    [255, 69, 0],
+    [255, 140, 0],
+    [255, 165, 0],
+    [255, 215, 0],
+    [184, 134, 11],
+    [218, 165, 32],
+    [238, 232, 170],
+    [189, 183, 107],
+    [240, 230, 140],
+    [128, 128, 0],
+    [154, 205, 50],
+    [85, 107, 47],
+    [107, 142, 35],
+    [124, 252, 0],
+    [127, 255, 0],
+    [173, 255, 47],
+    [0, 100, 0],
+    [34, 139, 34],
+    [0, 255, 0],
+    [50, 205, 50],
+    [144, 238, 144],
+    [152, 251, 152],
+    [143, 188, 143],
+    [0, 250, 154],
+    [0, 255, 127],
+    [46, 139, 87],
+    [102, 205, 170],
+    [60, 179, 113],
+    [32, 178, 170],
+    [47, 79, 79],
+    [0, 128, 128],
+    [0, 139, 139],
+    [0, 255, 255],
+    [0, 255, 255],
+    [224, 255, 255],
+    [0, 206, 209],
+    [64, 224, 208],
+    [72, 209, 204],
+    [175, 238, 238],
+    [127, 255, 212],
+    [176, 224, 230],
+    [95, 158, 160],
+    [70, 130, 180],
+    [100, 149, 237],
+    [0, 191, 255],
+    [30, 144, 255],
+    [173, 216, 230],
+    [135, 206, 235],
+    [135, 206, 250],
+    [25, 25, 112],
+    [0, 0, 128],
+    [0, 0, 139],
+    [0, 0, 205],
+    [65, 105, 225],
+    [138, 43, 226],
+    [75, 0, 130],
+    [72, 61, 139],
+    [106, 90, 205],
+    [123, 104, 238],
+    [147, 112, 219],
+    [139, 0, 139],
+    [148, 0, 211],
+    [153, 50, 204],
+    [186, 85, 211],
+    [128, 0, 128],
+    [216, 191, 216],
+    [221, 160, 221],
+    [238, 130, 238],
+    [255, 0, 255],
+    [218, 112, 214],
+    [199, 21, 133],
+    [219, 112, 147],
+    [255, 20, 147],
+    [255, 105, 180],
+    [255, 182, 193],
+    [255, 192, 203],
+    [250, 235, 215],
+    [245, 245, 220],
+    [255, 228, 196],
+    [255, 235, 205],
+    [245, 222, 179],
+    [255, 248, 220],
+    [255, 250, 205],
+    [250, 250, 210],
+    [255, 255, 224],
+    [139, 69, 19],
+    [160, 82, 45],
+    [210, 105, 30],
+    [205, 133, 63],
+    [244, 164, 96],
+    [222, 184, 135],
+    [210, 180, 140],
+    [188, 143, 143],
+    [255, 228, 181],
+    [255, 222, 173],
+    [255, 218, 185],
+    [255, 228, 225],
+    [255, 240, 245],
+    [250, 240, 230],
+    [253, 245, 230],
+    [255, 239, 213],
+    [255, 245, 238],
+    [245, 255, 250],
+    [112, 128, 144],
+    [119, 136, 153],
+    [176, 196, 222],
+    [230, 230, 250],
+    [255, 250, 240],
+    [240, 248, 255],
+    [248, 248, 255],
+    [240, 255, 240],
+    [255, 255, 240],
+    [240, 255, 255],
+    [255, 250, 250],
+    [192, 192, 192],
+    [220, 220, 220],
+    [245, 245, 245],
+]
+
+color_names = [
+    "white",
+    "red",
+    "green",
+    "blue",
+    "yellow",
+    "black",
+    "maroon",
+    "dark_red",
+    "brown",
+    "firebrick",
+    "crimson",
+    "tomato",
+    "coral",
+    "indian_red",
+    "light_coral",
+    "dark_salmon",
+    "salmon",
+    "light_salmon",
+    "orange_red",
+    "dark_orange",
+    "orange",
+    "gold",
+    "dark_golden_rod",
+    "golden_rod",
+    "pale_golden_rod",
+    "dark_khaki",
+    "khaki",
+    "olive",
+    "yellow_green",
+    "dark_olive_green",
+    "olive_drab",
+    "lawn_green",
+    "chartreuse",
+    "green_yellow",
+    "dark_green",
+    "forest_green",
+    "lime",
+    "lime_green",
+    "light_green",
+    "pale_green",
+    "dark_sea_green",
+    "medium_spring_green",
+    "spring_green",
+    "sea_green",
+    "medium_aqua_marine",
+    "medium_sea_green",
+    "light_sea_green",
+    "dark_slate_gray",
+    "teal",
+    "dark_cyan",
+    "aqua",
+    "cyan",
+    "light_cyan",
+    "dark_turquoise",
+    "turquoise",
+    "medium_turquoise",
+    "pale_turquoise",
+    "aqua_marine",
+    "powder_blue",
+    "cadet_blue",
+    "steel_blue",
+    "corn_flower_blue",
+    "deep_sky_blue",
+    "dodger_blue",
+    "light_blue",
+    "sky_blue",
+    "light_sky_blue",
+    "midnight_blue",
+    "navy",
+    "dark_blue",
+    "medium_blue",
+    "royal_blue",
+    "blue_violet",
+    "indigo",
+    "dark_slate_blue",
+    "slate_blue",
+    "medium_slate_blue",
+    "medium_purple",
+    "dark_magenta",
+    "dark_violet",
+    "dark_orchid",
+    "medium_orchid",
+    "purple",
+    "thistle",
+    "plum",
+    "violet",
+    "magenta",
+    "orchid",
+    "medium_violet_red",
+    "pale_violet_red",
+    "deep_pink",
+    "hot_pink",
+    "light_pink",
+    "pink",
+    "antique_white",
+    "beige",
+    "bisque",
+    "blanched_almond",
+    "wheat",
+    "corn_silk",
+    "lemon_chiffon",
+    "light_golden_rod_yellow",
+    "light_yellow",
+    "saddle_brown",
+    "sienna",
+    "chocolate",
+    "peru",
+    "sandy_brown",
+    "burly_wood",
+    "tan",
+    "rosy_brown",
+    "moccasin",
+    "navajo_white",
+    "peach_puff",
+    "misty_rose",
+    "lavender_blush",
+    "linen",
+    "old_lace",
+    "papaya_whip",
+    "sea_shell",
+    "mint_cream",
+    "slate_gray",
+    "light_slate_gray",
+    "light_steel_blue",
+    "lavender",
+    "floral_white",
+    "alice_blue",
+    "ghost_white",
+    "honeydew",
+    "ivory",
+    "azure",
+    "snow",
+    "silver",
+    "gainsboro",
+    "white_smoke",
+]
+
+color_id = dict([(n, k) for k, n in enumerate(color_names)])
+color_tokens = dict([(n, c) for n, c in zip(color_names, colors)])
+
+######################################################################
+
+
+def all_properties(height, width, nb_squares, square_i, square_j, square_c):
+    s = []
+
+    for r, c_r in [(k, color_names[square_c[k]]) for k in range(nb_squares)]:
+        s += [f"there is {c_r}"]
+
+        if square_i[r] >= height - height // 3:
+            s += [f"{c_r} bottom"]
+        if square_i[r] < height // 3:
+            s += [f"{c_r} top"]
+        if square_j[r] >= width - width // 3:
+            s += [f"{c_r} right"]
+        if square_j[r] < width // 3:
+            s += [f"{c_r} left"]
+
+        for t, c_t in [(k, color_names[square_c[k]]) for k in range(nb_squares)]:
+            if square_i[r] > square_i[t]:
+                s += [f"{c_r} below {c_t}"]
+            if square_i[r] < square_i[t]:
+                s += [f"{c_r} above {c_t}"]
+            if square_j[r] > square_j[t]:
+                s += [f"{c_r} right of {c_t}"]
+            if square_j[r] < square_j[t]:
+                s += [f"{c_r} left of {c_t}"]
+
+    return s
+
+
+######################################################################
+
+# Generates sequences
+
+
+def generate(
+    nb,
+    height,
+    width,
+    max_nb_squares=5,
+    max_nb_properties=10,
+    nb_colors=5,
+    pruner=None,
+):
+
+    assert nb_colors >= max_nb_squares and nb_colors <= len(color_tokens) - 1
+
+    descr = []
+
+    for n in range(nb):
+
+        nb_squares = torch.randint(max_nb_squares, (1,)) + 1
+        square_position = torch.randperm(height * width)[:nb_squares]
+
+        # color 0 is white and reserved for the background
+        square_c = torch.randperm(nb_colors)[:nb_squares] + 1
+        square_i = square_position.div(width, rounding_mode="floor")
+        square_j = square_position % width
+
+        img = [0] * height * width
+        for k in range(nb_squares):
+            img[square_position[k]] = square_c[k]
+
+        # generates all the true properties
+
+        s = all_properties(height, width, nb_squares, square_i, square_j, square_c)
+
+        if pruner is not None:
+            s = list(filter(pruner, s))
+
+        # picks at most max_nb_properties at random
+
+        nb_properties = torch.randint(max_nb_properties, (1,)) + 1
+        s = (
+            " <sep> ".join([s[k] for k in torch.randperm(len(s))[:nb_properties]])
+            + " <img> "
+            + " ".join([f"{color_names[n]}" for n in img])
+        )
+
+        descr += [s]
+
+    return descr
+
+
+######################################################################
+
+# Extracts the image after <img> in descr as a 1x3xHxW tensor
+
+
+def descr2img(descr, n, height, width):
+
+    if type(descr) == list:
+        return torch.cat([descr2img(d, n, height, width) for d in descr], 0)
+
+    if type(n) == list:
+        return torch.cat([descr2img(descr, k, height, width) for k in n], 0).unsqueeze(
+            0
+        )
+
+    def token2color(t):
+        try:
+            return color_tokens[t]
+        except KeyError:
+            return [128, 128, 128]
+
+    d = descr.split("<img>")
+    d = d[n + 1] if len(d) > n + 1 else ""
+    d = d.strip().split(" ")[: height * width]
+    d = d + ["<unk>"] * (height * width - len(d))
+    d = [token2color(t) for t in d]
+    img = torch.tensor(d).permute(1, 0)
+    img = img.reshape(1, 3, height, width)
+
+    return img
+
+
+######################################################################
+
+# Returns all the properties of the image after <img> in descr
+
+
+def descr2properties(descr, height, width):
+
+    if type(descr) == list:
+        return [descr2properties(d, height, width) for d in descr]
+
+    d = descr.split("<img>")
+    d = d[-1] if len(d) > 1 else ""
+    d = d.strip().split(" ")[: height * width]
+    if len(d) != height * width:
+        return []
+
+    seen = {}
+    for k, x in enumerate(d):
+        if x != color_names[0]:
+            if x in color_tokens:
+                if x in seen:
+                    return []
+            else:
+                return []
+            seen[x] = (color_id[x], k // width, k % width)
+
+    square_infos = tuple(zip(*seen.values()))
+
+    if square_infos:
+        square_c = torch.tensor(square_infos[0])
+        square_i = torch.tensor(square_infos[1])
+        square_j = torch.tensor(square_infos[2])
+    else:
+        square_c = torch.tensor([])
+        square_i = torch.tensor([])
+        square_j = torch.tensor([])
+
+    s = all_properties(height, width, len(seen), square_i, square_j, square_c)
+
+    return s
+
+
+######################################################################
+
+# Returns a triplet composed of (1) the total number of properties
+# before <img> in descr, (2) the total number of properties the image
+# after <img> verifies, and (3) the number of properties in (1) not in
+# (2)
+
+
+def nb_properties(descr, height, width, pruner=None):
+
+    if type(descr) == list:
+        return [nb_properties(d, height, width, pruner) for d in descr]
+
+    d = descr.split("<img>", 1)
+    if len(d) == 0:
+        return 0
+    d = d[0].strip().split("<sep>")
+    d = [x.strip() for x in d]
+
+    all_properties = set(descr2properties(descr, height, width))
+
+    if pruner is None:
+        requested_properties = set(d)
+    else:
+        requested_properties = set(filter(pruner, d))
+
+    missing_properties = requested_properties - all_properties
+
+    return (len(requested_properties), len(all_properties), len(missing_properties))
+
+
+######################################################################
+
+if __name__ == "__main__":
+    for n in range(16):
+        descr = generate(nb=1, height=12, width=16)
+
+        print(nb_properties(descr, height=12, width=16))
+
+        with open(f"picoclvr_example_{n:02d}.txt", "w") as f:
+            for d in descr:
+                f.write(f"{d}\n\n")
+
+        img = descr2img(descr, n=0, height=12, width=16)
+        if img.size(0) == 1:
+            img = F.pad(img, (1, 1, 1, 1), value=64)
+
+        torchvision.utils.save_image(
+            img / 255.0,
+            f"picoclvr_example_{n:02d}.png",
+            padding=1,
+            nrow=4,
+            pad_value=0.8,
+        )
+
+    import time
+
+    start_time = time.perf_counter()
+    descr = generate(nb=1000, height=12, width=16)
+    end_time = time.perf_counter()
+    print(f"{len(descr) / (end_time - start_time):.02f} samples per second")
+
+######################################################################
diff --git a/tensorstack.py b/tensorstack.py
new file mode 100755 (executable)
index 0000000..3218be1
--- /dev/null
@@ -0,0 +1,62 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+from torch import Tensor
+
+import sys
+
+
+def exception_hook(exc_type, exc_value, tb):
+    r"""Hacks the call stack message to show all the local variables in
+    case of RuntimeError or ValueError, and prints tensors as shape,
+    dtype and device.
+
+    """
+
+    repr_orig = Tensor.__repr__
+    Tensor.__repr__ = lambda x: f"{x.size()}:{x.dtype}:{x.device}"
+
+    while tb:
+        print("--------------------------------------------------\n")
+        filename = tb.tb_frame.f_code.co_filename
+        name = tb.tb_frame.f_code.co_name
+        line_no = tb.tb_lineno
+        print(f'  File "{filename}", line {line_no}, in {name}')
+        print(open(filename, "r").readlines()[line_no - 1])
+
+        if exc_type in {RuntimeError, ValueError}:
+            for n, v in tb.tb_frame.f_locals.items():
+                print(f"  {n} -> {v}")
+
+        print()
+        tb = tb.tb_next
+
+    Tensor.__repr__ = repr_orig
+
+    print(f"{exc_type.__name__}: {exc_value}")
+
+
+sys.excepthook = exception_hook
+
+######################################################################
+
+if __name__ == "__main__":
+
+    import torch
+
+    def dummy(a, b):
+        print(a @ b)
+
+    def blah(a, b):
+        c = b + b
+        dummy(a, c)
+
+    mmm = torch.randn(2, 3)
+    xxx = torch.randn(3)
+    # print(xxx@mmm)
+    blah(mmm, xxx)
+    blah(xxx, mmm)