--- /dev/null
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import math, re
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+
+def random_var(nb_variables=None, variables=None):
+ if variables is None:
+ return chr(ord("A") + torch.randint(nb_variables, (1,)).item())
+ else:
+ l = list(variables)
+ return l[torch.randint(len(l), (1,)).item()]
+
+
+def random_expr(variables, operand_max, budget):
+ if budget <= 5:
+ op = torch.randint(2, (1,)).item()
+ if op == 0 and len(variables) > 0:
+ return random_var(variables=variables)
+ else:
+ return str(torch.randint(operand_max + 1, (1,)).item())
+ else:
+ op = torch.randint(3, (1,)).item()
+ if op == 0:
+ e = random_expr(variables, operand_max, budget - 2)
+ if ("+" in e or "-" in e or "*" in e) and (e[0] != "(" or e[-1] != ")"):
+ return "(" + e + ")"
+ else:
+ return e
+ else:
+ b = 2 + torch.randint(budget - 5, (1,)).item()
+ e1 = random_expr(variables, operand_max, b)
+ e2 = random_expr(variables, operand_max, budget - b - 1)
+ if op == 1:
+ return e1 + "+" + e2
+ elif op == 2:
+ return e1 + "*" + e2
+
+
+def generate_program(nb_variables, operand_max, length):
+ s = ""
+ variables = set()
+
+ while len(s) < length:
+ v = random_var(nb_variables=nb_variables)
+ s += v + "=" + random_expr(variables, operand_max, budget=20) + ";"
+ variables.add(v)
+
+    return s, variables
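+
+
+# Example (illustrative only): generate_program(5, 9, 20) may return a
+# program such as "C=7;B=C*(2+C);A=B+5;" together with the set of
+# variables it assigns, here {"C", "B", "A"}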
+
+
+def generate_sequences(nb, nb_variables=5, length=20, operand_max=9, result_max=99):
+ assert nb_variables <= 26
+ sequences = []
+
+ for n in range(nb):
+        # We use length itself half of the time, and a value uniform
+        # between 1 and length otherwise. The actual length of the
+        # generated program can be slightly greater
+
+ l = min(length, 1 + torch.randint(length * 2, (1,)).item())
+ result = None
+        while result is None or max(result.values()) > result_max:
+ p, v = generate_program(nb_variables, operand_max, l)
+ v = ", ".join(['"' + v + '": ' + v for v in v])
+ ldict = {}
+ exec(p + "result={" + v + "}", globals(), ldict)
+ result = ldict["result"]
+
+ k = list(result.keys())
+ k.sort()
+ sequences.append(p + " " + "".join([v + ":" + str(result[v]) + ";" for v in k]))
+
+ return sequences
+
+
+def extract_results(seq):
+ f = lambda a: (a[0], -1 if a[1] == "" else int(a[1]))
+ results = [
+ dict([f(tuple(x.split(":"))) for x in re.findall("[A-Z]:[0-9]*", s)])
+ for s in seq
+ ]
+ return results
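+
+
+# Example (a minimal check): extract_results(["A=3;B=A*2; A:3;B:6;"])
+# returns [{"A": 3, "B": 6}]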
+
+
+if __name__ == "__main__":
+ import time
+
+ start_time = time.perf_counter()
+ sequences = generate_sequences(1000, length=40)
+ end_time = time.perf_counter()
+ for s in sequences[:10]:
+ print(s)
+ print(f"{len(sequences) / (end_time - start_time):.02f} samples per second")
+
+ print(extract_results(sequences[:10]))
--- /dev/null
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import sys, contextlib
+
+import torch
+from torch import Tensor
+
+######################################################################
+
+
+@contextlib.contextmanager
+def evaluation(*models):
+    with torch.inference_mode():
+        t = [(m, m.training) for m in models]
+        for m in models:
+            m.train(False)
+        try:
+            yield
+        finally:
+            # restore the training mode even if an exception occurred
+            for m, u in t:
+                m.train(u)
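+
+
+# Example usage (a minimal sketch):
+#
+#   model = torch.nn.Linear(3, 3)
+#   with evaluation(model):
+#       y = model(torch.randn(2, 3))  # inference mode, model in eval mode
+#   # the original training/eval state of model is restored on exit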
+
+
+######################################################################
+
+from torch.utils._python_dispatch import TorchDispatchMode
+
+
+def hasNaN(x):
+ if torch.is_tensor(x):
+ return x.numel() > 0 and x.isnan().max()
+ else:
+ try:
+ return any([hasNaN(y) for y in x])
+ except TypeError:
+ return False
+
+
+class NaNDetect(TorchDispatchMode):
+ def __torch_dispatch__(self, func, types, args, kwargs=None):
+ kwargs = kwargs or {}
+ res = func(*args, **kwargs)
+
+ if hasNaN(res):
+ raise RuntimeError(
+ f"Function {func}(*{args}, **{kwargs}) " "returned a NaN"
+ )
+ return res
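+
+
+# Example usage (a minimal sketch):
+#
+#   with NaNDetect():
+#       x = torch.zeros(2)
+#       y = x / x  # 0/0 produces NaNs, so this raises a RuntimeError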
+
+
+######################################################################
+
+
+def exception_hook(exc_type, exc_value, tb):
+ r"""Hacks the call stack message to show all the local variables
+ in case of relevant error, and prints tensors as shape, dtype and
+ device.
+
+ """
+
+ repr_orig = Tensor.__repr__
+ Tensor.__repr__ = lambda x: f"{x.size()}:{x.dtype}:{x.device}"
+
+ while tb:
+ print("--------------------------------------------------\n")
+ filename = tb.tb_frame.f_code.co_filename
+ name = tb.tb_frame.f_code.co_name
+ line_no = tb.tb_lineno
+ print(f' File "{filename}", line {line_no}, in {name}')
+ print(open(filename, "r").readlines()[line_no - 1])
+
+ if exc_type in {RuntimeError, ValueError, IndexError, TypeError}:
+ for n, v in tb.tb_frame.f_locals.items():
+ print(f" {n} -> {v}")
+
+ print()
+ tb = tb.tb_next
+
+ Tensor.__repr__ = repr_orig
+
+ print(f"{exc_type.__name__}: {exc_value}")
+
+
+def activate_tensorstack():
+ sys.excepthook = exception_hook
+
+
+######################################################################
+
+if __name__ == "__main__":
+ import torch
+
+ def dummy(a, b):
+ print(a @ b)
+
+ def blah(a, b):
+ c = b + b
+ dummy(a, c)
+
+ mmm = torch.randn(2, 3)
+ xxx = torch.randn(3)
+ # print(xxx@mmm)
+ blah(mmm, xxx)
+ blah(xxx, mmm)
--- /dev/null
+#!/usr/bin/env python
+
+import math
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+import cairo
+
+
+######################################################################
+
+
+def save_attention_image(
+ # image to save
+ filename,
+ tokens_input,
+ tokens_output,
+ # list of 2d tensors T2xT1, T3xT2, ..., TkxTk-1
+ attention_matrices,
+    # do not draw links with attention below this threshold
+    min_link_attention=0,
+ # draw only the strongest links necessary so that their summed
+ # attention is above min_total_attention
+ min_total_attention=None,
+ # draw only the top k links
+ k_top=None,
+ # the purely graphical settings
+ curved=True,
+ pixel_scale=8,
+ token_gap=15,
+ layer_gap=25,
+ y_eps=0.5,
+ padding=10,
+):
+    if k_top is not None:
+        am = []
+        for m in attention_matrices:
+            # rank of each entry in decreasing order of attention
+            ranks = m.sort(dim=-1, descending=True).indices.argsort(dim=-1)
+            am.append(m * (ranks < k_top))
+        attention_matrices = am
+
+    if min_total_attention is not None:
+        am = []
+        for m in attention_matrices:
+            s = m.sort(dim=-1)
+            # mask, in sorted order, keeping the strongest links whose
+            # summed attention reaches min_total_attention
+            k = 1 - (s.values.cumsum(-1) < 1 - min_total_attention).long()
+            # scatter the mask back to the original link positions
+            b = k.new_zeros(k.size()).scatter_(dim=-1, index=s.indices, src=k)
+            am.append(m * b)
+        attention_matrices = am
+
+ surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None)
+
+ ctx = cairo.Context(surface)
+ ctx.scale(pixel_scale, pixel_scale)
+
+ ctx.set_source_rgb(0.0, 0.0, 0.0)
+ ctx.set_font_size(4.0)
+ # ctx.select_font_face("Arial", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL)
+
+ x, y = 0, 0
+
+ ctx.set_line_width(0.25)
+ for d in range(len(attention_matrices)):
+ at = attention_matrices[d].to("cpu")
+ ni = torch.arange(at.size(0))[:, None].expand_as(at)
+ nj = torch.arange(at.size(1))[None, :].expand_as(at)
+ at = at.flatten()
+ o = at.sort().indices
+ at = at[o]
+ ni = ni.flatten()[o]
+ nj = nj.flatten()[o]
+ for i, j, a in zip(ni, nj, at):
+ if a > 0 and a >= min_link_attention:
+ c = 1 - a.item()
+ ctx.set_source_rgb(c, c, c)
+ ax, ay = j * token_gap, y - y_eps
+ ctx.move_to(ax, ay)
+ dx, dy = i * token_gap, y - layer_gap + y_eps
+ if curved:
+ bx, by = ax, ay - layer_gap * 0.5
+ cx, cy = dx, dy + layer_gap * 0.5
+ ctx.curve_to(bx, by, cx, cy, dx, dy)
+ else:
+ ctx.line_to(dx, dy)
+ ctx.stroke()
+ y -= layer_gap
+
+    for d in range(0, len(attention_matrices) + 1):
+        nb_tokens = (
+            attention_matrices[0].size(-1)
+            if d == 0
+            else attention_matrices[d - 1].size(-2)
+        )
+        for n in range(nb_tokens):
+ xc, yc = n * token_gap, -d * layer_gap
+ ctx.set_source_rgb(1.0, 1.0, 1.0)
+ ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
+ ctx.fill()
+ ctx.set_source_rgb(0.0, 0.0, 0.0)
+ ctx.arc(xc, yc, token_gap / 20, 0, 2 * math.pi)
+ ctx.fill()
+
+ ctx.set_source_rgb(0.0, 0.0, 0.0)
+
+ for k, t in enumerate(tokens_input):
+ s = str(t)
+ (
+ x_bearing,
+ y_bearing,
+ width_t,
+ height_t,
+ x_advance,
+ y_advance,
+ ) = ctx.text_extents(s)
+ ctx.move_to(k * token_gap - width_t / 2, 2 * token_gap / 5)
+ ctx.show_text(s)
+
+ for k, t in enumerate(tokens_output):
+ s = str(t)
+ (
+ x_bearing,
+ y_bearing,
+ width_t,
+ height_t,
+ x_advance,
+ y_advance,
+ ) = ctx.text_extents(s)
+ ctx.move_to(
+ k * token_gap - width_t / 2,
+ -token_gap / 5 - len(attention_matrices) * layer_gap,
+ )
+ ctx.show_text(s)
+
+ x, y, width, height = surface.ink_extents()
+ x -= padding
+ y -= padding
+ width += 2 * padding
+ height += 2 * padding
+ pdf_surface = cairo.PDFSurface(filename, width, height)
+ ctx_pdf = cairo.Context(pdf_surface)
+ ctx_pdf.set_source_surface(surface, -x, -y)
+ ctx_pdf.paint()
+ pdf_surface.finish()
+
+
+######################################################################
+
+if __name__ == "__main__":
+ import mygpt
+
+ tokens_output = ["<wat>", "-", 3, 4, "<end>"]
+ tokens_input = [""] + tokens_output[:-1]
+
+ vocabulary_size = 3
+ x = torch.randint(vocabulary_size, (1, len(tokens_input)))
+
+ model = mygpt.MyGPT(
+ vocabulary_size=vocabulary_size,
+ dim_model=4,
+ dim_keys=2,
+ dim_hidden=2,
+ nb_heads=2,
+ nb_blocks=5,
+ dropout=0.1,
+ causal=True,
+ )
+
+ model.eval()
+ model.record_attention()
+
+ y1 = model(mygpt.BracketedSequence(x)).x
+
+ attention_matrices = [m[0, 0] for m in model.retrieve_attention()]
+
+ # attention_matrices = [torch.rand(*s) for s in [ (4,5),(3,4),(8,3),(5,8) ]]
+
+ save_attention_image(
+ "attention.pdf",
+ tokens_input,
+ tokens_output,
+ attention_matrices,
+ # k_top=2,
+ min_total_attention=0.9,
+ )
--- /dev/null
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import math
+import torch, torchvision
+import torch.nn.functional as F
+
+name_shapes = ["A", "B", "C", "D", "E", "F"]
+
+name_colors = ["red", "yellow", "blue", "green", "white", "purple"]
+
+######################################################################
+
+
+class GridFactory:
+ def __init__(
+ self,
+ size=6,
+ max_nb_items=4,
+ max_nb_transformations=3,
+ nb_questions=4,
+ ):
+ assert size % 2 == 0
+ self.size = size
+ self.max_nb_items = max_nb_items
+ self.max_nb_transformations = max_nb_transformations
+ self.nb_questions = nb_questions
+
+ def generate_scene(self):
+ nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2
+ col = torch.full((self.size * self.size,), -1)
+ shp = torch.full((self.size * self.size,), -1)
+ a = torch.randperm(len(name_colors) * len(name_shapes))[:nb_items]
+ col[:nb_items] = a % len(name_colors)
+ shp[:nb_items] = a // len(name_colors)
+ i = torch.randperm(self.size * self.size)
+ col = col[i]
+ shp = shp[i]
+ return col.reshape(self.size, self.size), shp.reshape(self.size, self.size)
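+
+    # A scene is a pair (col, shp) of size x size tensors, with -1 in
+    # empty cells and indices into name_colors / name_shapes elsewhere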
+
+ def random_transformations(self, scene):
+ col, shp = scene
+
+ descriptions = []
+ nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item()
+ transformations = torch.randint(5, (nb_transformations,))
+
+ for t in transformations:
+ if t == 0:
+ col, shp = col.flip(0), shp.flip(0)
+ descriptions += ["<chg> vertical flip"]
+ elif t == 1:
+ col, shp = col.flip(1), shp.flip(1)
+ descriptions += ["<chg> horizontal flip"]
+ elif t == 2:
+ col, shp = col.flip(0).t(), shp.flip(0).t()
+ descriptions += ["<chg> rotate 90 degrees"]
+ elif t == 3:
+ col, shp = col.flip(0).flip(1), shp.flip(0).flip(1)
+ descriptions += ["<chg> rotate 180 degrees"]
+ elif t == 4:
+ col, shp = col.flip(1).t(), shp.flip(1).t()
+ descriptions += ["<chg> rotate 270 degrees"]
+
+ col, shp = col.contiguous(), shp.contiguous()
+
+ return (col, shp), descriptions
+
+ def print_scene(self, scene):
+ col, shp = scene
+
+ # for i in range(self.size):
+ # for j in range(self.size):
+ # if col[i,j] >= 0:
+ # print(f"at ({i},{j}) {name_colors[col[i,j]]} {name_shapes[shp[i,j]]}")
+
+ for i in range(self.size):
+ for j in range(self.size):
+ if col[i, j] >= 0:
+ print(f"{name_colors[col[i,j]][0]}{name_shapes[shp[i,j]]}", end="")
+ elif j == 0:
+ print(" +", end="")
+ else:
+ print("-+", end="")
+ if j < self.size - 1:
+ print("--", end="")
+ else:
+ print("")
+ if i < self.size - 1:
+ for j in range(self.size - 1):
+ print(" | ", end="")
+ print(" |")
+
+ def grid_positions(self, scene):
+ col, shp = scene
+
+ properties = []
+
+ for i in range(self.size):
+ for j in range(self.size):
+ if col[i, j] >= 0:
+ n = f"{name_colors[col[i,j]]} {name_shapes[shp[i,j]]}"
+ properties += [f"a {n} at {i} {j}"]
+
+ return properties
+
+ def all_properties(self, scene):
+ col, shp = scene
+
+ properties = []
+
+ for i1 in range(self.size):
+ for j1 in range(self.size):
+ if col[i1, j1] >= 0:
+ n1 = f"{name_colors[col[i1,j1]]} {name_shapes[shp[i1,j1]]}"
+ properties += [f"there is a {n1}"]
+ if i1 < self.size // 2:
+ properties += [f"a {n1} is in the top half"]
+ if i1 >= self.size // 2:
+ properties += [f"a {n1} is in the bottom half"]
+ if j1 < self.size // 2:
+ properties += [f"a {n1} is in the left half"]
+ if j1 >= self.size // 2:
+ properties += [f"a {n1} is in the right half"]
+ for i2 in range(self.size):
+ for j2 in range(self.size):
+ if col[i2, j2] >= 0:
+ n2 = f"{name_colors[col[i2,j2]]} {name_shapes[shp[i2,j2]]}"
+ if i1 > i2:
+ properties += [f"a {n1} is below a {n2}"]
+ if i1 < i2:
+ properties += [f"a {n1} is above a {n2}"]
+ if j1 > j2:
+ properties += [f"a {n1} is right of a {n2}"]
+ if j1 < j2:
+ properties += [f"a {n1} is left of a {n2}"]
+ if abs(i1 - i2) + abs(j1 - j2) == 1:
+ properties += [f"a {n1} is next to a {n2}"]
+
+ return properties
+
+ def generate_scene_and_questions(self):
+ while True:
+ while True:
+ start_scene = self.generate_scene()
+ scene, transformations = self.random_transformations(start_scene)
+ true = self.all_properties(scene)
+ if len(true) >= self.nb_questions:
+ break
+
+ for a in range(10):
+ col, shp = scene
+ col, shp = col.view(-1), shp.view(-1)
+ p = torch.randperm(col.size(0))
+ col, shp = col[p], shp[p]
+ other_scene = (
+ col.view(self.size, self.size),
+ shp.view(self.size, self.size),
+ )
+
+ false = self.all_properties(other_scene)
+
+                # We sometimes add properties from a totally different
+                # scene to get negative "there is a xxx xxx"
+                # properties
+ if torch.rand(1).item() < 0.2:
+ other_scene = self.generate_scene()
+ false += self.all_properties(other_scene)
+
+ false = list(set(false) - set(true))
+ if len(false) >= self.nb_questions:
+ break
+
+ if a < 10:
+ break
+
+ true = [true[k] for k in torch.randperm(len(true))[: self.nb_questions]]
+ false = [false[k] for k in torch.randperm(len(false))[: self.nb_questions]]
+ true = ["<prop> " + q + " <ans> true" for q in true]
+ false = ["<prop> " + q + " <ans> false" for q in false]
+
+ union = true + false
+ questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]]
+
+ result = " ".join(
+ ["<obj> " + x for x in self.grid_positions(start_scene)]
+ + transformations
+ + questions
+ )
+
+ return start_scene, scene, result
+
+ def generate_samples(self, nb, progress_bar=None):
+ result = []
+
+ r = range(nb)
+ if progress_bar is not None:
+ r = progress_bar(r)
+
+ for _ in r:
+ result.append(self.generate_scene_and_questions()[2])
+
+ return result
+
+
+######################################################################
+
+if __name__ == "__main__":
+ import time
+
+ grid_factory = GridFactory()
+
+ # start_time = time.perf_counter()
+ # samples = grid_factory.generate_samples(10000)
+ # end_time = time.perf_counter()
+ # print(f"{len(samples) / (end_time - start_time):.02f} samples per second")
+
+ start_scene, scene, questions = grid_factory.generate_scene_and_questions()
+ print()
+ print("-- Original scene -----------------------------")
+ print()
+ grid_factory.print_scene(start_scene)
+ print()
+ print("-- Transformed scene --------------------------")
+ print()
+ grid_factory.print_scene(scene)
+ print()
+ print("-- Sequence -----------------------------------")
+ print()
+ print(questions)
+
+######################################################################
--- /dev/null
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import math, sys, argparse, time, tqdm, os, datetime, warnings
+
+import torch, torchvision
+from torch import nn
+from torch.nn import functional as F
+
+import ffutils
+import mygpt, tasks, problems
+
+######################################################################
+
+if torch.cuda.is_available():
+ device = torch.device("cuda")
+ torch.backends.cuda.matmul.allow_tf32 = True
+else:
+ device = torch.device("cpu")
+
+######################################################################
+
+parser = argparse.ArgumentParser(
+ description="An implementation of GPT with cache.",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+)
+
+parser.add_argument(
+ "--task",
+ type=str,
+ default="twotargets",
+ help="byheart, learnop, guessop, mixing, memory, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp",
+)
+
+parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
+
+parser.add_argument("--result_dir", type=str, default=None)
+
+parser.add_argument("--seed", type=int, default=0)
+
+parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
+
+########################################
+
+parser.add_argument("--nb_epochs", type=int, default=50)
+
+parser.add_argument("--batch_size", type=int, default=None)
+
+parser.add_argument("--nb_train_samples", type=int, default=None)
+
+parser.add_argument("--nb_test_samples", type=int, default=None)
+
+parser.add_argument("--optim", type=str, default="adam")
+
+########################################
+
+parser.add_argument("--nb_warmup_iter", type=int, default=100)
+
+parser.add_argument("--nb_decay_iter", type=int, default=5000)
+
+parser.add_argument("--learning_rate", type=float, default=6e-4)
+
+parser.add_argument("--min_learning_rate", type=float, default=6e-5)
+
+########################################
+
+parser.add_argument("--model", type=str, default=None)
+
+parser.add_argument("--attention", type=str, default=None)
+
+parser.add_argument("--dim_model", type=int, default=None)
+
+parser.add_argument("--dim_keys", type=int, default=None)
+
+parser.add_argument("--dim_hidden", type=int, default=None)
+
+parser.add_argument("--nb_heads", type=int, default=None)
+
+parser.add_argument("--nb_lines", type=int, default=None)
+
+parser.add_argument("--caterpillar_height", type=int, default=None)
+
+parser.add_argument("--rho", type=float, default=0.0)
+
+parser.add_argument("--dim_rec_v", type=int, default=None)
+
+parser.add_argument("--nb_blocks", type=int, default=None)
+
+parser.add_argument("--dropout", type=float, default=0.1)
+
+########################################
+
+parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
+
+parser.add_argument("--no_checkpoint", action="store_true", default=False)
+
+parser.add_argument("--overwrite_results", action="store_true", default=False)
+
+parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
+
+##############################
+# rpl options
+
+parser.add_argument("--rpl_nb_starting_values", type=int, default=3)
+
+parser.add_argument("--rpl_max_input", type=int, default=9)
+
+parser.add_argument("--rpl_prog_len", type=int, default=8)
+
+parser.add_argument("--rpl_nb_runs", type=int, default=5)
+
+parser.add_argument("--rpl_no_prog", action="store_true", default=False)
+
+##############################
+# grid options
+
+parser.add_argument("--grid_size", type=int, default=6)
+
+##############################
+# picoclvr options
+
+parser.add_argument("--picoclvr_nb_colors", type=int, default=5)
+
+parser.add_argument("--picoclvr_height", type=int, default=12)
+
+parser.add_argument("--picoclvr_width", type=int, default=16)
+
+parser.add_argument("--picocvlr_prune_properties", type=str, default="none")
+
+##############################
+# Maze options
+
+parser.add_argument("--maze_height", type=int, default=13)
+
+parser.add_argument("--maze_width", type=int, default=21)
+
+parser.add_argument("--maze_nb_walls", type=int, default=15)
+
+##############################
+# Snake options
+
+parser.add_argument("--snake_height", type=int, default=9)
+
+parser.add_argument("--snake_width", type=int, default=12)
+
+parser.add_argument("--snake_nb_colors", type=int, default=5)
+
+parser.add_argument("--snake_length", type=int, default=200)
+
+##############################
+# Stack options
+
+parser.add_argument("--stack_nb_steps", type=int, default=100)
+
+parser.add_argument("--stack_nb_stacks", type=int, default=3)
+
+parser.add_argument("--stack_nb_digits", type=int, default=3)
+
+parser.add_argument("--stack_fraction_values_for_train", type=float, default=0.75)
+
+##############################
+# Expr options
+
+parser.add_argument("--expr_nb_variables", type=int, default=5)
+
+parser.add_argument("--expr_sequence_length", type=int, default=40)
+
+parser.add_argument("--expr_operand_max", type=int, default=9)
+
+parser.add_argument("--expr_result_max", type=int, default=99)
+
+parser.add_argument("--expr_input_file", type=str, default=None)
+
+##############################
+# Memory
+
+parser.add_argument("--memory_len_total", type=int, default=32)
+
+##############################
+# Mixing
+
+parser.add_argument("--mixing_hard", action="store_true", default=False)
+
+parser.add_argument("--mixing_deterministic_start", action="store_true", default=False)
+
+######################################################################
+
+args = parser.parse_args()
+
+assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
+
+if args.result_dir is None:
+ args.result_dir = f"results_{args.task}_{args.model}"
+
+######################################################################
+
+default_task_args = {
+ "addition": {
+ "model": "352M",
+ "batch_size": 25,
+ "nb_train_samples": 250000,
+ "nb_test_samples": 10000,
+ },
+ "byheart": {
+ "model": "37M",
+ "batch_size": 25,
+ "nb_train_samples": 50000,
+ "nb_test_samples": 10000,
+ },
+ "expr": {
+ "model": "352M",
+ "batch_size": 25,
+ "nb_train_samples": 2500000,
+ "nb_test_samples": 10000,
+ },
+ "grid": {
+ "model": "37M",
+ "batch_size": 25,
+ "nb_train_samples": 250000,
+ "nb_test_samples": 10000,
+ },
+ "qmlp": {
+ "model": "37M",
+ "batch_size": 10,
+ "nb_train_samples": 100000,
+ "nb_test_samples": 1000,
+ },
+ "guessop": {
+ "model": "352M",
+ "batch_size": 25,
+ "nb_train_samples": 1000000,
+ "nb_test_samples": 10000,
+ },
+ "learnop": {
+ "model": "37M",
+ "batch_size": 25,
+ "nb_train_samples": 50000,
+ "nb_test_samples": 10000,
+ },
+ "maze": {
+ "model": "37M",
+ "batch_size": 5,
+ "nb_train_samples": 100000,
+ "nb_test_samples": 10000,
+ },
+ "picoclvr": {
+ "model": "37M",
+ "batch_size": 25,
+ "nb_train_samples": 250000,
+ "nb_test_samples": 10000,
+ },
+ "rpl": {
+ "model": "352M",
+ "batch_size": 5,
+ "nb_train_samples": 2500000,
+ "nb_test_samples": 10000,
+ },
+ "snake": {
+ "model": "37M",
+ "batch_size": 25,
+ "nb_train_samples": 250000,
+ "nb_test_samples": 10000,
+ },
+ "stack": {
+ "model": "37M",
+ "batch_size": 25,
+ "nb_train_samples": 100000,
+ "nb_test_samples": 1000,
+ },
+ "twotargets": {
+ "model": "37M",
+ "batch_size": 25,
+ "nb_train_samples": 50000,
+ "nb_test_samples": 10000,
+ },
+ "memory": {
+ "model": "37M",
+ "batch_size": 25,
+ "nb_train_samples": 25000,
+ "nb_test_samples": 10000,
+ },
+ "mixing": {
+ "model": "37M",
+ "batch_size": 25,
+ "nb_train_samples": 250000,
+ "nb_test_samples": 10000,
+ },
+ "mnist": {
+ "model": "37M",
+ "batch_size": 10,
+ "nb_train_samples": 60000,
+ "nb_test_samples": 10000,
+ },
+}
+
+if args.task in default_task_args:
+ for k, v in default_task_args[args.task].items():
+ if getattr(args, k) is None:
+ setattr(args, k, v)
+
+######################################################################
+
+default_model_args = {
+ "17K": {
+ "attention": "mha",
+ "dim_model": 32,
+ "dim_keys": 32,
+ "dim_hidden": 32,
+ "nb_heads": 2,
+ "dim_rec_v": 16,
+ "nb_blocks": 2,
+ },
+ "17K-C": {
+ "attention": "caterpillar",
+ "dim_model": 32,
+ "dim_keys": 32,
+ "dim_hidden": 32,
+ "nb_heads": 2,
+ "nb_lines": 16,
+ "caterpillar_height": 4,
+ "dim_rec_v": 16,
+ "nb_blocks": 2,
+ },
+ "4M": {
+ "attention": "mha",
+ "dim_model": 256,
+ "dim_keys": 32,
+ "dim_hidden": 1024,
+ "nb_heads": 4,
+ "dim_rec_v": 64,
+ "nb_blocks": 6,
+ },
+ "4M-C": {
+ "attention": "caterpillar",
+ "dim_model": 256,
+ "dim_keys": 32,
+ "dim_hidden": 1024,
+ "nb_heads": 4,
+ "nb_lines": 32,
+ "caterpillar_height": 4,
+ "dim_rec_v": 64, # dim_model / nb_heads
+ "nb_blocks": 6,
+ },
+ "37M": {
+ "dim_model": 512,
+ "dim_keys": 64,
+ "dim_hidden": 2048,
+ "nb_heads": 8,
+ "dim_rec_v": 64,
+ "nb_blocks": 12,
+ },
+ "37M-C": {
+ "attention": "caterpillar",
+ "dim_model": 512,
+ "dim_keys": 64,
+ "dim_hidden": 2048,
+ "nb_heads": 8,
+ "nb_lines": 256,
+ "caterpillar_height": 32,
+ "dim_rec_v": 64,
+ "nb_blocks": 12,
+ },
+ "122M": {
+ "attention": "mha",
+ "dim_model": 768,
+ "dim_keys": 64,
+ "dim_hidden": 2048,
+ "nb_heads": 8,
+ "dim_rec_v": 96,
+ "nb_blocks": 24,
+ },
+ "122M-C": {
+ "attention": "caterpillar",
+ "dim_model": 768,
+ "dim_keys": 64,
+ "dim_hidden": 2048,
+ "nb_heads": 8,
+ "nb_lines": 128,
+ "dim_rec_v": 96,
+ "nb_blocks": 24,
+ },
+ "352M": {
+ "attention": "mha",
+ "dim_model": 1024,
+ "dim_keys": 64,
+ "dim_hidden": 2048,
+ "nb_heads": 8,
+ "dim_rec_v": 128,
+ "nb_blocks": 48,
+ },
+ "352M-C": {
+ "attention": "caterpillar",
+ "dim_model": 1024,
+ "dim_keys": 64,
+ "dim_hidden": 2048,
+ "nb_heads": 8,
+ "nb_lines": 128,
+ "dim_rec_v": 128,
+ "nb_blocks": 48,
+ },
+}
+
+if args.model in default_model_args:
+ for k, v in default_model_args[args.model].items():
+ if getattr(args, k) is None:
+ setattr(args, k, v)
+else:
+ raise ValueError(f"Unknown model {args.model}")
+
+######################################################################
+
+try:
+ os.mkdir(args.result_dir)
+except FileExistsError:
+ if not args.overwrite_results:
+ print(f"result directory {args.result_dir} already exists")
+ exit(1)
+
+log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
+
+if args.seed >= 0:
+ # torch.backends.cudnn.deterministic = True
+ # torch.backends.cudnn.benchmark = False
+ # torch.use_deterministic_algorithms(True)
+ torch.manual_seed(args.seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(args.seed)
+
+######################################################################
+
+
+def log_string(s):
+ t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
+
+ if log_file is not None:
+ log_file.write(t + s + "\n")
+ log_file.flush()
+
+ print(t + s)
+ sys.stdout.flush()
+
+
+with os.popen("sha256sum *.py") as f:
+ for l in f:
+ log_string(f"sha256sum {l.strip()}")
+
+now = time.strftime("%Y%m%d-%H%M%S", time.localtime())
+os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py *.sh")
+
+log_string(f"argv {' '.join(sys.argv)}")
+
+for n in vars(args):
+ log_string(f"args.{n} {getattr(args, n)}")
+
+
+######################################################################
+
+# from nanoGPT
+
+
+def get_lr(it):
+ # 1) linear warmup for warmup_iter steps
+ if it < args.nb_warmup_iter:
+ return args.learning_rate * it / args.nb_warmup_iter
+ # 2) if it > nb_decay_iter, return min learning rate
+ if it > args.nb_decay_iter:
+ return args.min_learning_rate
+ # 3) in between, use cosine decay down to min learning rate
+ decay_ratio = (it - args.nb_warmup_iter) / (
+ args.nb_decay_iter - args.nb_warmup_iter
+ )
+ coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
+ return args.min_learning_rate + coeff * (
+ args.learning_rate - args.min_learning_rate
+ )
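+
+
+# With the default arguments this gives, for instance:
+#   get_lr(0)    -> 0.0
+#   get_lr(50)   -> 3e-4  (mid-warmup)
+#   get_lr(100)  -> 6e-4  (peak learning rate)
+#   get_lr(5000) -> 6e-5  (fully decayed)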
+
+
+######################################################################
+
+
+def picoclvr_pruner_horizontal_green(p):
+ return not ("green" in p and ("left" in p or "right" in p))
+
+
+picoclvr_pruner_train = (
+ picoclvr_pruner_horizontal_green
+ if args.picocvlr_prune_properties in {"train+eval"}
+ else None
+)
+
+picoclvr_pruner_eval = (
+ (lambda p: not picoclvr_pruner_horizontal_green(p))
+ if args.picocvlr_prune_properties in {"train+eval", "eval"}
+ else None
+)
+
+######################################################################
+
+device_data = device
+
+if args.task == "byheart":
+ task = tasks.SandBox(
+ problem=problems.ProblemByHeart(),
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ logger=log_string,
+ device=device_data,
+ )
+ args.max_percents_of_test_in_train = -1
+
+elif args.task == "learnop":
+ task = tasks.SandBox(
+ problem=problems.ProblemLearnOperator(),
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ logger=log_string,
+ device=device_data,
+ )
+
+
+elif args.task == "guessop":
+ task = tasks.SandBox(
+ problem=problems.ProblemGuessOperator(),
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ logger=log_string,
+ device=device_data,
+ )
+
+
+elif args.task == "twotargets":
+ task = tasks.SandBox(
+ problem=problems.ProblemTwoTargets(),
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ logger=log_string,
+ device=device_data,
+ )
+
+elif args.task == "memory":
+ task = tasks.SandBox(
+ problem=problems.ProblemMemory(len_total=args.memory_len_total),
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ logger=log_string,
+ device=device_data,
+ )
+
+elif args.task == "mixing":
+ task = tasks.SandBox(
+ problem=problems.ProblemMixing(
+ hard=args.mixing_hard, random_start=not args.mixing_deterministic_start
+ ),
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ logger=log_string,
+ device=device_data,
+ )
+
+elif args.task == "addition":
+ task = tasks.SandBox(
+ problem=problems.ProblemAddition(),
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ logger=log_string,
+ device=device_data,
+ )
+
+elif args.task == "picoclvr":
+ task = tasks.PicoCLVR(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ height=args.picoclvr_height,
+ width=args.picoclvr_width,
+ nb_colors=args.picoclvr_nb_colors,
+ logger=log_string,
+ device=device_data,
+ pruner_train=picoclvr_pruner_train,
+ pruner_eval=picoclvr_pruner_eval,
+ )
+
+elif args.task == "mnist":
+ task = tasks.MNIST(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ device=device_data,
+ )
+
+elif args.task == "maze":
+ task = tasks.Maze(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ height=args.maze_height,
+ width=args.maze_width,
+ nb_walls=args.maze_nb_walls,
+ device=device_data,
+ )
+
+elif args.task == "snake":
+ task = tasks.Snake(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ height=args.snake_height,
+ width=args.snake_width,
+ nb_colors=args.snake_nb_colors,
+ length=args.snake_length,
+ prompt_length=args.snake_length // 2,
+ device=device_data,
+ )
+
+elif args.task == "stack":
+ task = tasks.Stack(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ logger=log_string,
+ nb_steps=args.stack_nb_steps,
+ nb_stacks=args.stack_nb_stacks,
+ nb_digits=args.stack_nb_digits,
+ fraction_values_for_train=args.stack_fraction_values_for_train,
+ device=device_data,
+ )
+
+elif args.task == "expr":
+ task = tasks.Expr(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ nb_variables=args.expr_nb_variables,
+ sequence_length=args.expr_sequence_length,
+ operand_max=args.expr_operand_max,
+ result_max=args.expr_result_max,
+ batch_size=args.batch_size,
+ device=device_data,
+ )
+
+elif args.task == "rpl":
+ task = tasks.RPL(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ nb_starting_values=args.rpl_nb_starting_values,
+ max_input=args.rpl_max_input,
+ prog_len=args.rpl_prog_len,
+ nb_runs=args.rpl_nb_runs,
+ no_prog=args.rpl_no_prog,
+ logger=log_string,
+ device=device_data,
+ )
+
+elif args.task == "grid":
+ task = tasks.Grid(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ size=args.grid_size,
+ logger=log_string,
+ device=device_data,
+ )
+
+elif args.task == "qmlp":
+ task = tasks.QMLP(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ result_dir=args.result_dir,
+ logger=log_string,
+ device=device_data,
+ )
+
+else:
+ raise ValueError(f"Unknown task {args.task}")
+
+######################################################################
+
+log_string(f"device {device}")
+
+vocabulary_size = task.vocabulary_size()
+
+log_string(f"vocabulary_size {vocabulary_size}")
+
+##############################
+
+model = mygpt.MyGPT(
+ vocabulary_size=vocabulary_size,
+ dim_model=args.dim_model,
+ dim_keys=args.dim_keys,
+ dim_hidden=args.dim_hidden,
+ nb_heads=args.nb_heads,
+ nb_lines=args.nb_lines,
+ caterpillar_height=args.caterpillar_height,
+ dim_rec_v=args.dim_rec_v,
+ nb_blocks=args.nb_blocks,
+ causal=True,
+ dropout=args.dropout,
+ attention_layer=args.attention,
+)
+
+model.to(device)
+
+nb_parameters = sum(p.numel() for p in model.parameters())
+log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
+
+######################################################################
+
+nb_epochs_finished = 0
+
+if args.no_checkpoint:
+ log_string(f"not trying to load checkpoint.")
+
+else:
+ try:
+ checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
+ checkpoint = torch.load(checkpoint_name)
+ nb_epochs_finished = checkpoint["nb_epochs_finished"]
+ model.load_state_dict(checkpoint["model_state"])
+ torch.set_rng_state(checkpoint["rng_state"])
+ if torch.cuda.is_available():
+ torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])
+
+ log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.")
+
+ except FileNotFoundError:
+ log_string("starting from scratch.")
+
+    except Exception:
+ log_string("error when loading the checkpoint.")
+ exit(1)
+
+######################################################################
+
+if args.task == "expr" and args.expr_input_file is not None:
+ task.produce_results(
+ n_epoch=nb_epochs_finished,
+ model=model,
+ result_dir=args.result_dir,
+ logger=log_string,
+ deterministic_synthesis=args.deterministic_synthesis,
+ input_file=args.expr_input_file,
+ )
+
+ exit(0)
+
+######################################################################
+
+assert args.nb_epochs > 0
+nb_epochs = args.nb_epochs
+
+# Compute the entropy of the training tokens
+
+token_count = 0
+for input in task.batches(split="train"):
+ token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
+token_probas = token_count / token_count.sum()
+entropy = -torch.xlogy(token_probas, token_probas).sum()
+train_set_perplexity = math.exp(entropy)
+
+######################################################################
+# A bit of paranoia never hurts
+
+if args.max_percents_of_test_in_train >= 0:
+
+ def subsets_as_tuples(batches, cs):
+ s = set()
+ for batch in batches:
+ for x in batch:
+ s.add(tuple([v.item() for v in x]))
+ if len(s) == cs:
+ yield s
+ s = set()
+ yield s
+
+ nb_test, nb_in_train = 0, 0
+ for test_subset in subsets_as_tuples(task.batches(split="test"), 25000):
+ in_train = set()
+ for train_subset in subsets_as_tuples(task.batches(split="train"), 25000):
+ in_train.update(test_subset.intersection(train_subset))
+ nb_in_train += len(in_train)
+ nb_test += len(test_subset)
+
+ log_string(
+ f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
+ )
+
+ assert (
+ nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
+ ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"
+
+##############################
+
+nb_samples_seen = 0
+
+if nb_epochs_finished >= nb_epochs:
+ task.produce_results(
+ n_epoch=nb_epochs_finished,
+ model=model,
+ result_dir=args.result_dir,
+ logger=log_string,
+ deterministic_synthesis=args.deterministic_synthesis,
+ )
+
+time_pred_result = None
+
+it = 0
+
+for n_epoch in range(nb_epochs_finished, nb_epochs):
+ if args.optim == "sgd":
+ optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
+ elif args.optim == "adam":
+ optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+ elif args.optim == "adamw":
+ optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
+ else:
+ raise ValueError(f"Unknown optimizer {args.optim}.")
+
+ model.train()
+
+ nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0
+
+ for input in task.batches(split="train"):
+ model.reset_inner_loss()
+ input = input.to(device)
+
+ output = model(mygpt.BracketedSequence(input)).x
+ loss = F.cross_entropy(output.transpose(1, 2), input)
+ inner_loss = model.get_inner_loss()
+
+ acc_train_loss += loss.item() * input.size(0)
+ acc_train_inner_loss += inner_loss.item() * input.size(0)
+
+ nb_train_samples += input.size(0)
+ nb_samples_seen += input.size(0)
+
+ total_loss = loss + (args.rho * inner_loss if args.rho > 0 else 0.0)
+
+ it += 1
+ lr = get_lr(it)
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = lr
+
+ # log_string(f"learning_rate {lr}")
+
+ optimizer.zero_grad()
+ total_loss.backward()
+ optimizer.step()
+
+    with torch.no_grad():
+ model.eval()
+
+ nb_test_samples, acc_test_loss = 0, 0.0
+
+ for input in task.batches(split="test"):
+ input = input.to(device)
+
+ output = model(mygpt.BracketedSequence(input)).x
+ loss = F.cross_entropy(output.transpose(1, 2), input)
+ acc_test_loss += loss.item() * input.size(0)
+ nb_test_samples += input.size(0)
+
+ log_string(
+ f"loss {n_epoch} train_loss {acc_train_loss/nb_train_samples} train_inner_loss {acc_train_inner_loss/nb_train_samples} test_prediction {acc_test_loss/nb_test_samples}"
+ )
+
+ task.produce_results(
+ n_epoch=n_epoch,
+ model=model,
+ result_dir=args.result_dir,
+ logger=log_string,
+ deterministic_synthesis=args.deterministic_synthesis,
+ )
+
+ train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+ test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
+
+ log_string(
+ f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
+ )
+
+ time_current_result = datetime.datetime.now()
+ if time_pred_result is not None:
+ log_string(
+ f"next_result {time_current_result + (time_current_result - time_pred_result)}"
+ )
+ time_pred_result = time_current_result
+
+ checkpoint = {
+ "nb_epochs_finished": n_epoch + 1,
+ "model_state": model.state_dict(),
+ "rng_state": torch.get_rng_state(),
+ }
+
+ if torch.cuda.is_available():
+ checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state()
+
+ checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
+ torch.save(checkpoint, checkpoint_name)
+ log_string(f"saved checkpoint {checkpoint_name}")
+
+######################################################################
--- /dev/null
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import torch, torchvision
+
+######################################################################
+
+v_empty, v_wall, v_start, v_goal, v_path = 0, 1, 2, 3, 4
+
+
+def create_maze(h=11, w=17, nb_walls=8):
+ assert h % 2 == 1 and w % 2 == 1
+
+ a, k = 0, 0
+
+ while k < nb_walls:
+ while True:
+ if a == 0:
+ m = torch.zeros(h, w, dtype=torch.int64)
+ m[0, :] = 1
+ m[-1, :] = 1
+ m[:, 0] = 1
+ m[:, -1] = 1
+
+ r = torch.rand(4)
+
+ if r[0] <= 0.5:
+ i1, i2, j = (
+ int((r[1] * h).item()),
+ int((r[2] * h).item()),
+ int((r[3] * w).item()),
+ )
+ i1, i2, j = i1 - i1 % 2, i2 - i2 % 2, j - j % 2
+ i1, i2 = min(i1, i2), max(i1, i2)
+ if i2 - i1 > 1 and i2 - i1 <= h / 2 and m[i1 : i2 + 1, j].sum() <= 1:
+ m[i1 : i2 + 1, j] = 1
+ break
+ else:
+ i, j1, j2 = (
+ int((r[1] * h).item()),
+ int((r[2] * w).item()),
+ int((r[3] * w).item()),
+ )
+ i, j1, j2 = i - i % 2, j1 - j1 % 2, j2 - j2 % 2
+ j1, j2 = min(j1, j2), max(j1, j2)
+ if j2 - j1 > 1 and j2 - j1 <= w / 2 and m[i, j1 : j2 + 1].sum() <= 1:
+ m[i, j1 : j2 + 1] = 1
+ break
+ a += 1
+
+ if a > 10 * nb_walls:
+ a, k = 0, 0
+
+ k += 1
+
+ return m
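+
+
+# create_maze returns an h x w int64 tensor with 1 (v_wall) for wall
+# cells and 0 (v_empty) elsewhere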
+
+
+######################################################################
+
+
+def compute_distance(walls, goal_i, goal_j):
+ max_length = walls.numel()
+ dist = torch.full_like(walls, max_length)
+
+ dist[goal_i, goal_j] = 0
+ pred_dist = torch.empty_like(dist)
+
+ while True:
+ pred_dist.copy_(dist)
+ d = (
+ torch.cat(
+ (
+ dist[None, 1:-1, 0:-2],
+ dist[None, 2:, 1:-1],
+ dist[None, 1:-1, 2:],
+ dist[None, 0:-2, 1:-1],
+ ),
+ 0,
+ ).min(dim=0)[0]
+ + 1
+ )
+
+ dist[1:-1, 1:-1] = torch.min(dist[1:-1, 1:-1], d)
+ dist = walls * max_length + (1 - walls) * dist
+
+ if dist.equal(pred_dist):
+ return dist * (1 - walls)
+
+
+######################################################################
+
+
+def compute_policy(walls, goal_i, goal_j):
+ distance = compute_distance(walls, goal_i, goal_j)
+ distance = distance + walls.numel() * walls
+
+ value = distance.new_full((4,) + distance.size(), walls.numel())
+ value[0, :, 1:] = distance[:, :-1] # <
+ value[1, :, :-1] = distance[:, 1:] # >
+ value[2, 1:, :] = distance[:-1, :] # ^
+ value[3, :-1, :] = distance[1:, :] # v
+
+ proba = (value.min(dim=0)[0][None] == value).float()
+ proba = proba / proba.sum(dim=0)[None]
+ proba = proba * (1 - walls) + walls.float() / 4
+
+ return proba
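+
+
+# Example usage (a minimal sketch; (1, 1) is always an empty cell since
+# walls are carved on even rows and columns):
+#
+#   m = create_maze(h=11, w=17, nb_walls=8)
+#   proba = compute_policy(m, 1, 1)  # 4 x h x w, action order <, >, ^, v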
+
+
+def stationary_densities(mazes, policies):
+ policies = policies * (mazes != v_goal)[:, None]
+ start = (mazes == v_start).nonzero(as_tuple=True)
+ probas = mazes.new_zeros(mazes.size(), dtype=torch.float32)
+ pred_probas = probas.clone()
+ probas[start] = 1.0
+
+ while not pred_probas.equal(probas):
+ pred_probas.copy_(probas)
+ probas.zero_()
+ probas[:, 1:, :] += pred_probas[:, :-1, :] * policies[:, 3, :-1, :]
+ probas[:, :-1, :] += pred_probas[:, 1:, :] * policies[:, 2, 1:, :]
+ probas[:, :, 1:] += pred_probas[:, :, :-1] * policies[:, 1, :, :-1]
+ probas[:, :, :-1] += pred_probas[:, :, 1:] * policies[:, 0, :, 1:]
+ probas[start] = 1.0
+
+ return probas
+
+
+######################################################################
+
+
+def mark_path(walls, i, j, goal_i, goal_j, policy):
+ action = torch.distributions.categorical.Categorical(
+ policy.permute(1, 2, 0)
+ ).sample()
+ n, nmax = 0, walls.numel()
+ while i != goal_i or j != goal_j:
+ di, dj = [(0, -1), (0, 1), (-1, 0), (1, 0)][action[i, j]]
+ i, j = i + di, j + dj
+ assert walls[i, j] == 0
+ walls[i, j] = v_path
+ n += 1
+ assert n < nmax
+
+
+def path_optimality(ref_paths, paths):
+ return (ref_paths == v_path).long().flatten(1).sum(1) == (
+ paths == v_path
+ ).long().flatten(1).sum(1)
+
+
+def path_correctness(mazes, paths):
+ still_ok = (mazes - (paths * (paths != v_path))).view(mazes.size(0), -1).abs().sum(
+ 1
+ ) == 0
+ reached = still_ok.new_zeros(still_ok.size())
+ current, pred_current = paths.clone(), paths.new_zeros(paths.size())
+ goal = (mazes == v_goal).long()
+ while not pred_current.equal(current):
+ pred_current.copy_(current)
+ u = (current == v_start).long()
+ possible_next = (
+ u[:, 2:, 1:-1] + u[:, 0:-2, 1:-1] + u[:, 1:-1, 2:] + u[:, 1:-1, 0:-2] > 0
+ ).long()
+ u = u[:, 1:-1, 1:-1]
+ reached += ((goal[:, 1:-1, 1:-1] * possible_next).sum((1, 2)) == 1) * (
+ (current == v_path).sum((1, 2)) == 0
+ )
+ current[:, 1:-1, 1:-1] = (1 - u) * current[:, 1:-1, 1:-1] + (
+ v_start - v_path
+ ) * (possible_next * (current[:, 1:-1, 1:-1] == v_path))
+ still_ok *= (current == v_start).sum((1, 2)) <= 1
+
+ return still_ok * reached
+
+
+######################################################################
+
+
+def create_maze_data(
+ nb, height=11, width=17, nb_walls=8, dist_min=10, progress_bar=lambda x: x
+):
+ mazes = torch.empty(nb, height, width, dtype=torch.int64)
+ paths = torch.empty(nb, height, width, dtype=torch.int64)
+ policies = torch.empty(nb, 4, height, width)
+
+ for n in progress_bar(range(nb)):
+ maze = create_maze(height, width, nb_walls)
+ i = (maze == v_empty).nonzero()
+ while True:
+ start, goal = i[torch.randperm(i.size(0))[:2]]
+ if (start - goal).abs().sum() >= dist_min:
+ break
+ start_i, start_j, goal_i, goal_j = start[0], start[1], goal[0], goal[1]
+
+ policy = compute_policy(maze, goal_i, goal_j)
+ path = maze.clone()
+ mark_path(path, start_i, start_j, goal_i, goal_j, policy)
+ maze[start_i, start_j] = v_start
+ maze[goal_i, goal_j] = v_goal
+ path[start_i, start_j] = v_start
+ path[goal_i, goal_j] = v_goal
+
+ mazes[n] = maze
+ paths[n] = path
+ policies[n] = policy
+
+ return mazes, paths, policies
+
+
+######################################################################
+
+
+def save_image(
+ name,
+ mazes,
+ target_paths=None,
+ predicted_paths=None,
+ path_correct=None,
+ path_optimal=None,
+):
+ colors = torch.tensor(
+ [
+ [255, 255, 255], # empty
+ [0, 0, 0], # wall
+ [0, 255, 0], # start
+ [127, 127, 255], # goal
+ [255, 0, 0], # path
+ ]
+ )
+
+ mazes = mazes.cpu()
+
+ c_mazes = (
+ colors[mazes.reshape(-1)].reshape(mazes.size() + (-1,)).permute(0, 3, 1, 2)
+ )
+
+ imgs = c_mazes.unsqueeze(1)
+
+ if target_paths is not None:
+ target_paths = target_paths.cpu()
+
+ c_target_paths = (
+ colors[target_paths.reshape(-1)]
+ .reshape(target_paths.size() + (-1,))
+ .permute(0, 3, 1, 2)
+ )
+
+ imgs = torch.cat((imgs, c_target_paths.unsqueeze(1)), 1)
+
+ if predicted_paths is not None:
+ predicted_paths = predicted_paths.cpu()
+ c_predicted_paths = (
+ colors[predicted_paths.reshape(-1)]
+ .reshape(predicted_paths.size() + (-1,))
+ .permute(0, 3, 1, 2)
+ )
+ imgs = torch.cat((imgs, c_predicted_paths.unsqueeze(1)), 1)
+
+ img = torch.tensor([255, 255, 0]).view(1, -1, 1, 1)
+
+ # NxKxCxHxW
+ if path_optimal is not None:
+ path_optimal = path_optimal.cpu().long().view(-1, 1, 1, 1)
+ img = (
+ img * (1 - path_optimal)
+ + torch.tensor([0, 255, 0]).view(1, -1, 1, 1) * path_optimal
+ )
+
+ if path_correct is not None:
+ path_correct = path_correct.cpu().long().view(-1, 1, 1, 1)
+ img = img * path_correct + torch.tensor([255, 0, 0]).view(1, -1, 1, 1) * (
+ 1 - path_correct
+ )
+
+ img = img.expand(
+ -1, -1, imgs.size(3) + 2, 1 + imgs.size(1) * (1 + imgs.size(4))
+ ).clone()
+
+ print(f"{img.size()=} {imgs.size()=}")
+
+ for k in range(imgs.size(1)):
+ img[
+ :,
+ :,
+ 1 : 1 + imgs.size(3),
+ 1 + k * (1 + imgs.size(4)) : 1 + k * (1 + imgs.size(4)) + imgs.size(4),
+ ] = imgs[:, k]
+
+ img = img.float() / 255.0
+
+ torchvision.utils.save_image(img, name, nrow=4, padding=1, pad_value=224.0 / 256)
+
+
+######################################################################
+
+if __name__ == "__main__":
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ mazes, paths, policies = create_maze_data(8)
+ mazes, paths = mazes.to(device), paths.to(device)
+ save_image("test.png", mazes=mazes, target_paths=paths, predicted_paths=paths)
+ print(path_correctness(mazes, paths))
+
+######################################################################
--- /dev/null
+#!/usr/bin/env python
+
+import torch
+
+from torch.utils.cpp_extension import load_inline
+
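+# The C++ function below performs a greedy allocation of "lines" to
+# heads: at every time step the line loads decay by a constant factor,
+# then each head with a positive request gets the least-loaded line if
+# its request exceeds that line's current load (the load then becomes
+# the request), and -1 otherwise.
+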
+cpp_source = """
+std::vector<torch::Tensor> greedy_lines_allocation(torch::Tensor load_start, float decay, torch::Tensor line_requests) {
+ auto nb_lines = load_start.size(1);
+ auto batch_size = line_requests.size(0);
+ auto nb_heads = line_requests.size(1);
+ auto T = line_requests.size(2);
+
+ auto load_start_a = load_start.accessor<float,2>();
+ auto line_requests_a = line_requests.accessor<float,3>();
+
+ auto load = torch::empty({batch_size, nb_lines, T});
+ auto load_a = load.accessor<float,3>();
+
+ auto allocation_result = torch::empty({batch_size,nb_heads,T},torch::TensorOptions().dtype(torch::kInt64));
+ auto allocation_result_a = allocation_result.accessor<long,3>();
+
+ for(int n = 0; n < batch_size; n++) {
+ for(int t = 0; t < T; t++) {
+ for(int l = 0; l < nb_lines; l++) {
+ if(t == 0) {
+          load_a[n][l][t] = decay * load_start_a[n][l];
+        } else {
+          load_a[n][l][t] = decay * load_a[n][l][t-1];
+ }
+ }
+ for(int h = 0; h < nb_heads; h++) {
+ if(line_requests_a[n][h][t] > 0) {
+ int l_lowest_load;
+ for(int l = 0; l < nb_lines; l++) {
+ if(l == 0 || load_a[n][l][t]<load_a[n][l_lowest_load][t]) l_lowest_load=l;
+ }
+ if(load_a[n][l_lowest_load][t] < line_requests_a[n][h][t]) {
+ allocation_result_a[n][h][t] = l_lowest_load;
+ load_a[n][l_lowest_load][t] = line_requests_a[n][h][t];
+ } else {
+ allocation_result_a[n][h][t] = -1;
+ }
+ } else {
+ allocation_result_a[n][h][t] = -1;
+ }
+ }
+ }
+ }
+
+ return {allocation_result,load};
+}
+"""
+
+######################################################################
+
+allocator_module = load_inline(
+ name="allocator_module",
+ cpp_sources=[cpp_source],
+ functions=["greedy_lines_allocation"],
+ build_directory="/tmp/",
+ # verbose=True,
+)
+
+lines_allocation = allocator_module.greedy_lines_allocation
+
+######################################################################
+
+if __name__ == "__main__":
+ N, H, L, T = 1, 1, 3, 20
+
+ load_start = torch.rand(N, L)
+ requests = (2 * torch.rand(N, H, T) - 1).clamp(min=0)
+
+ print("load_start", load_start)
+
+ print("requests", requests)
+
+ alloc, load = lines_allocation(load_start, 0.99, requests)
+
+ print("alloc", alloc)
+
+ print("load", load)
--- /dev/null
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+# This is an implementation from scratch of a "GPT", that is, a model
+# composed of several causal self-attention blocks. It is equipped
+# with a caching mechanism for keys and values to avoid an O(N^3) cost
+# for auto-regression.
+
+import math, warnings
+
+import torch, einops
+
+from torch import nn
+from torch.nn import functional as F
+from functorch.dim import dims
+
+import ffutils
+
+# import memload
+
+######################################################################
+
+# A BracketedSequence is a BxTx... tensor together with a first time
+# step and a number of time steps to compute.
+
+# Modules able to process it expect that they will have to process a
+# first bracket starting at t=0, followed by a succession of brackets
+# that move forward in time, do not overlap, and cover the axis T with
+# no holes.
+#
+# Although it is more general, for a classical prompt-conditioned
+# auto-regressive process it will be a first bracket starting at 0 and
+# of arbitrary length for the "prompt", followed by brackets of length
+# 1 for the successive tokens.
+#
+# Modules able to process brackets may implement a cache that is
+# reset when the input bracket starts at t=0
+
+
+class BracketedSequence:
+ def __init__(self, x, first=None, nb=None, init_cache=None):
+ self.x = x
+ assert (first is None and nb is None and init_cache is None) or (
+ first is not None and nb is not None and init_cache is not None
+ )
+
+ self.first = 0 if first is None else first
+ self.nb = x.size(1) if nb is None else nb
+ self.init_cache = True if init_cache is None else init_cache
+
+ def slice(self):
+ return self.x[:, self.first : self.first + self.nb]
+
+ def complete(self):
+ return self.first == 0 and self.nb == self.x.size(1)
+
+
+######################################################################
+
+
+class CacheWrapper(nn.Module):
+ def __init__(self, *f):
+ super().__init__()
+ self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+
+ def forward(self, bs):
+ if bs.init_cache:
+ y = self.f(bs.slice())
+ self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:]))
+ self.cache_y[:, bs.first : bs.first + bs.nb] = y
+ else:
+ assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
+ assert bs.first + bs.nb <= self.cache_y.size(1)
+ self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
+
+ return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
+
+
+##############################
+
+
+class WithResidual(nn.Module):
+ def __init__(self, *f):
+ super().__init__()
+ self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+
+ def forward(self, bs):
+ return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb, bs.init_cache)
+
+
+##############################
+
+
+class AddPositionalEncoding(nn.Module):
+ def __init__(self, len_max):
+ super().__init__()
+ self.len_max = len_max
+
+        # [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
+
+ def forward(self, bs):
+ if bs.init_cache:
+ t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
+ :, None
+ ]
+ j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
+ None, :
+ ]
+ k = j % 2
+ self.pe = torch.sin(
+ t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
+ )
+ self.cache_y = bs.x.new(bs.x.size())
+
+ self.cache_y[:, bs.first : bs.first + bs.nb] = (
+ bs.slice() + self.pe[bs.first : bs.first + bs.nb]
+ )
+
+ return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
+
+
+import pscan
+
+
+# A is /.../xT, X is /.../xTxD, Y_init is /.../xD
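+#
+# pscan.pscan(A, X, Y_init) computes the linear recurrence
+# Y[:, t] = A[:, t] * Y[:, t - 1] + X[:, t], with Y[:, -1] = Y_init,
+# presumably with a parallel prefix scan (inferred from its uses below;
+# the pscan module itself is not part of this diff)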
+
+
+def pscan_dim(A, X, Y_init, dim=-2):
+ s = X.size()
+ a, T, b = s[:dim].numel(), s[dim], s[dim + 1 :].numel()
+
+ A = A.reshape(a, T, *s[dim + 1 : -1])
+ X = X.reshape(a, T, *s[dim + 1 : -1], -1)
+
+ if Y_init is None:
+ Y_init = X.new_zeros(a, *s[dim + 1 : -1], X.size(-1))
+ else:
+ Y_init = Y_init.reshape(a, *s[dim + 1 : -1], -1)
+
+ Y = pscan.pscan(A, X, Y_init).reshape(s)
+
+ return Y
+
+
+def pscan_shape(A, X, Y_init):
+ s = X.size()
+ A = A.reshape(-1, s[-2])
+ X = X.reshape(-1, s[-2], s[-1])
+
+ if Y_init is None:
+ Y_init = X.new_zeros(X.size(0), s[-1])
+ else:
+ Y_init = Y_init.reshape(-1, s[-1])
+
+ Y = pscan.pscan(A, X, Y_init).reshape(s)
+
+ return Y
+
+
+def nsum_shape(X, Y_init):
+ s = X.size()
+ X = X.reshape(-1, s[-2], s[-1]) # ntd
+
+ Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1])
+ result = []
+
+ for k in range(X.size(1)):
+ Y = Y + X[:, k]
+ Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
+ result.append(Y)
+
+ return torch.cat(result, dim=1).reshape(s)
+
+
+##############################
+
+
+class DumbRec(nn.Module):
+ def __init__(
+ self,
+ dim_in,
+ dim_qk,
+ dim_v,
+ nb_heads,
+ nb_lines,
+ attention_dropout=0.0,
+ len_max=1e5,
+ ):
+ super().__init__()
+
+ def randw(*d):
+ return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+
+ self.nb_lines = nb_lines
+ self.attention_dropout = attention_dropout
+
+ self.k_star = randw(nb_lines, dim_qk)
+
+ self.w_qw = randw(nb_heads, dim_qk, dim_in)
+ self.w_qr = randw(nb_heads, dim_qk, dim_in)
+ # self.w_k = randw(nb_heads, dim_qk, dim_in)
+ self.w_v = randw(nb_heads, dim_v, dim_in)
+ self.w_o = randw(dim_v * nb_heads, dim_in)
+
+ def reset_inner_loss(self):
+ self.acc_attention = 0
+ self.acc_nb = 0
+
+ def get_inner_loss(self):
+ warnings.warn("l2 regularization", RuntimeWarning)
+ return (self.acc_attention / self.acc_nb).pow(2).sum()
+ # return torch.tensor([0], device=self.w_qw.device)
+
+ def forward(self, bs):
+ x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
+
+ if bs.init_cache:
+ self.rec_v = x_q.new_zeros(
+ x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
+ )
+ # self.rec_k = x_q.new_zeros(
+ # x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
+ # )
+ self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
+
+ ######################################################################
+ # Prepare the keys
+
+ k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
+
+ warnings.warn("rotating key barrel", RuntimeWarning)
+ k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
+ t_barrel = torch.arange(t0, t1, device=k_star.device)
+ t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
+ l_barrel = (
+ torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
+ ) % k_star.size(0)
+ k_star = k_star[l_barrel, t_barrel]
+
+ ######################################################################
+ # Compute the recurrent state
+
+ qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
+
+ v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
+ # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
+
+ aw = torch.einsum(
+ "nhtd,ltd->nhlt",
+ qw,
+ k_star,
+ ) / math.sqrt(self.w_qw.size(1))
+
+ aw = aw.softmax(dim=2) # nhlt
+
+ if self.training:
+ self.acc_attention += aw.sum(dim=(0, 1, 3))
+ self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
+
+ aw = F.dropout(aw, self.attention_dropout, self.training)
+
+ A = 1 - aw.sum(dim=1) # nlt
+
+ V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
+ # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
+
+ if t0 == 0:
+ V0 = None
+ # K0 = None
+ else:
+ V0 = self.rec_v[:, :, t0 - 1]
+ # K0 = self.rec_k[:, :, t0 - 1]
+
+ self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
+ # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
+
+ ######################################################################
+ # compute the readout
+
+ qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)
+
+ ar = torch.einsum(
+ "nhtd,ld->nhlt",
+ qr,
+ # self.rec_k[:, :, t0:t1],
+ self.k_star,
+ ) / math.sqrt(self.w_qr.size(1))
+
+ ar = ar.softmax(dim=2) # nhlt
+
+ ar = F.dropout(ar, self.attention_dropout, self.training)
+
+ y = torch.einsum(
+ "nhlt,nltd->nthd",
+ ar,
+ self.rec_v[:, :, t0:t1],
+ ).flatten(2)
+
+ self.cache_y[:, t0:t1] = y @ self.w_o
+
+ return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
+
+
+##############################
+
+
+class KVRec(nn.Module):
+ def __init__(
+ self,
+ dim_in,
+ dim_qk,
+ dim_v,
+ nb_heads,
+ nb_lines,
+ attention_dropout=0.0,
+ len_max=1e5,
+ ):
+ super().__init__()
+
+ def randw(*d):
+ return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+
+ self.nb_lines = nb_lines
+ self.attention_dropout = attention_dropout
+
+ self.k_star = randw(nb_lines, dim_qk)
+
+ self.w_qw = randw(nb_heads, dim_qk, dim_in)
+ self.w_qr = randw(nb_heads, dim_qk, dim_in)
+ self.w_k = randw(nb_heads, dim_qk, dim_in)
+ self.w_v = randw(nb_heads, dim_v, dim_in)
+ self.w_o = randw(dim_v * nb_heads, dim_in)
+
+ def reset_inner_loss(self):
+ self.acc_attention = 0
+ self.acc_nb = 0
+
+ def get_inner_loss(self):
+ warnings.warn("l2 regularization", RuntimeWarning)
+ return (self.acc_attention / self.acc_nb).pow(2).sum()
+ # return torch.tensor([0], device=self.w_qw.device)
+ # warnings.warn("side regularization", RuntimeWarning)
+ # return (
+ # (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
+ # )
+ # return torch.tensor([0], device=self.w_qw.device)
+
+ def forward(self, bs):
+ x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
+
+ # n,h,l,t,d = dims(5)
+
+ if bs.init_cache:
+ self.rec_v = x_q.new_zeros(
+ x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
+ )
+ self.rec_k = x_q.new_zeros(
+ x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
+ )
+ self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
+
+ ######################################################################
+ # Prepare the keys
+
+ warnings.warn("rotating key barrel", RuntimeWarning)
+ k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
+ t_barrel = torch.arange(t0, t1, device=k_star.device)
+ t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
+ l_barrel = (
+ torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
+ ) % k_star.size(0)
+ k_star = k_star[l_barrel, t_barrel]
+
+ ######################################################################
+ # Compute the recurrent state
+
+ qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
+
+ v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
+ k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
+
+ aw = torch.einsum(
+ "nhtd,ltd->nhlt",
+ qw,
+ k_star,
+ ) / math.sqrt(self.w_qw.size(1))
+
+ aw = aw.softmax(dim=2) # nhlt
+
+ if self.training:
+ # We want all the memory lines to be used similarly
+ self.acc_attention += aw.sum(dim=(0, 1, 3)) # Sum across N x H x _ x T
+ self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
+
+ aw = F.dropout(aw, self.attention_dropout, self.training)
+
+ A = 1 - aw.sum(dim=1) # nlt
+
+ V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
+ K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
+
+ if t0 == 0:
+ V0 = None
+ K0 = None
+ else:
+ V0 = self.rec_v[:, :, t0 - 1]
+ K0 = self.rec_k[:, :, t0 - 1]
+
+ self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
+ self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
+
+ ######################################################################
+ # compute the readout
+
+ qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)
+
+ ar = torch.einsum(
+ "nhtd,nltd->nhlt",
+ qr,
+ self.rec_k[:, :, t0:t1],
+ ) / math.sqrt(self.w_qr.size(1))
+
+ ar = ar.softmax(dim=2) # nhlt
+
+ ar = F.dropout(ar, self.attention_dropout, self.training)
+
+ y = torch.einsum(
+ "nhlt,nltd->nthd",
+ ar,
+ self.rec_v[:, :, t0:t1],
+ ).flatten(2)
+
+ self.cache_y[:, t0:t1] = y @ self.w_o
+
+ return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
+
+
+##############################
+
+
+def moving_window(x, dim, win_dim, win_size):
+ size, stride = x.size(), x.stride()
+ size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
+ size = size[:win_dim] + (win_size,) + size[win_dim:]
+ stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]
+
+ return x.as_strided(size=size, stride=stride)
+
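+
+# A small illustration (hypothetical helper): moving_window builds,
+# through strides and without copying, all the windows of length
+# win_size sliding along dim, i.e. y[..., t, l, ...] = x[..., t + l, ...].
+
+
+def _check_moving_window():
+    x = torch.randn(2, 3, 7, 5)
+    y = moving_window(x, dim=2, win_dim=3, win_size=4)
+    assert y.size() == (2, 3, 4, 4, 5)
+    for t in range(y.size(2)):
+        assert y[:, :, t].equal(x[:, :, t : t + 4])
+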
+
+##############################
+
+
+class Caterpillar(nn.Module):
+ def __init__(
+ self,
+ dim_in,
+ dim_qk,
+ dim_v,
+ nb_heads,
+ caterpillar_length,
+ caterpillar_height,
+ attention_dropout=0.0,
+ len_max=1e5,
+ ):
+ super().__init__()
+
+ warnings.warn("Caterpillar", RuntimeWarning)
+
+ def randw(*d):
+ return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+
+ self.caterpillar_length = caterpillar_length
+ self.caterpillar_height = caterpillar_height
+ self.attention_dropout = attention_dropout
+
+ self.w_G = randw(nb_heads, caterpillar_height, dim_in)
+ self.b_G = nn.Parameter(
+ torch.full(
+ (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
+ )
+ )
+
+ self.w_K = randw(nb_heads, dim_qk, dim_in)
+ self.w_V = randw(nb_heads, dim_v, dim_in)
+ self.w_Q = randw(nb_heads, dim_qk, dim_in)
+ self.w_O = randw(dim_v * nb_heads, dim_in)
+
+ self.init_K_rec = randw(caterpillar_height, caterpillar_length, dim_qk)
+ self.init_V_rec = randw(caterpillar_height, caterpillar_length, dim_v)
+
+ def reset_inner_loss(self):
+ self.acc_attention = 0
+ self.acc_nb = 0
+
+ def get_inner_loss(self):
+ # warnings.warn("l2 regularization", RuntimeWarning)
+ # return (self.acc_attention / self.acc_nb).pow(2).sum()
+ return torch.tensor([0], device=self.w_Q.device)
+
+ def forward(self, bs):
+ # Name the main dimensions, to make the source a bit clearer
+
+ X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb
+
+ N = bs.x.size(0)
+ T = bs.x.size(1)
+ DV = self.w_V.size(1)
+ DK = self.w_K.size(1)
+ Dout = self.w_O.size(1)
+ CH = self.caterpillar_height
+ CL = self.caterpillar_length
+
+ assert (
+ t0 >= CL and (t1 - t0) % CL == 0
+ ), "bs.first should be at least caterpillar_length, and bs.nb should be a multiple of caterpillar_length"
+
+ if bs.init_cache:
+ self.rec_V = X.new_zeros(N, CH, T, DV)
+ self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :]
+ self.rec_K = X.new_zeros(N, CH, T, DK)
+ self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :]
+ self.cache_Y = X.new_zeros(N, T, Dout)
+
+ ######################################################################
+ # Compute the recurrent state
+
+ G = (
+ torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
+ ).sigmoid()
+
+ V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
+ K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
+
+ A = 1 - G.sum(1)
+ gated_V = torch.einsum("nhet,nhtd->netd", G, V)
+ gated_K = torch.einsum("nhet,nhtd->netd", G, K)
+
+ init_rec_V = self.rec_V[:, :, t0 - CL : t0]
+ init_rec_K = self.rec_K[:, :, t0 - CL : t0]
+
+ A = A.unflatten(2, (-1, CL))
+ gated_V = gated_V.unflatten(2, (-1, CL))
+ gated_K = gated_K.unflatten(2, (-1, CL))
+
+ next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
+ next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)
+
+ self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3)
+ self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)
+
+ ######################################################################
+ # compute the readout
+
+ Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)
+
+ uv = moving_window(
+ self.rec_V[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
+ )
+
+ uk = moving_window(
+ self.rec_K[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
+ )
+
+ ar = torch.einsum(
+ "nhtd,nftld->nhtfl",
+ Q,
+ uk,
+ ) / math.sqrt(DK)
+
+ ar = ar.flatten(3).softmax(dim=3).view(ar.size())
+
+ ar = F.dropout(ar, self.attention_dropout, self.training)
+
+ Y = torch.einsum(
+ "nhtfl,nftld->nthd",
+ ar,
+ uv,
+ ).flatten(2)
+
+ self.cache_Y[:, t0:t1] = Y @ self.w_O
+
+ return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
+
+
+##############################
+
+
+class QKVAttention(nn.Module):
+ def __init__(
+ self,
+ dim_in,
+ dim_qk,
+ dim_v,
+ nb_heads=1,
+ causal=False,
+ attention_dropout=0.0,
+ ):
+ super().__init__()
+
+ def randw(*d):
+ return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+
+ self.causal = causal
+ self.attention_dropout = attention_dropout
+ self.record_attention = False
+
+ self.w_q = randw(nb_heads, dim_qk, dim_in)
+ self.w_k = randw(nb_heads, dim_qk, dim_in)
+ self.w_v = randw(nb_heads, dim_v, dim_in)
+ self.w_o = randw(dim_v * nb_heads, dim_in)
+
+ def forward(self, bs):
+ x_q = bs.x
+
+ assert (
+ self.causal or bs.complete()
+ ), "Partial evaluation is only possible for causal models"
+
+ if bs.init_cache:
+ self.cache_k = x_q.new_zeros(
+ x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
+ )
+ self.cache_v = x_q.new_zeros(
+ x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
+ )
+ self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
+
+ q = torch.einsum("ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_q)
+
+ self.cache_k[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
+ "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_k
+ )
+ self.cache_v[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
+ "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_v
+ )
+
+ a = torch.einsum(
+ "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs.first + bs.nb]
+ ) / math.sqrt(self.w_q.size(1))
+
+ if self.causal:
+ if bs.init_cache:
+ self.cache_attzero = (
+ torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
+ < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
+ )
+ a = a.masked_fill(
+ self.cache_attzero[
+ :, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb
+ ],
+ float("-inf"),
+ )
+
+ a = a.softmax(dim=3)
+
+ if self.record_attention:
+ self.a = a
+
+ a = F.dropout(a, self.attention_dropout, self.training)
+
+ y = torch.einsum(
+ "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs.first + bs.nb]
+ ).flatten(2)
+
+ self.cache_y[:, bs.first : bs.first + bs.nb] = y @ self.w_o
+
+ return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
+
+
+##############################
+
+
+class MyGPT(nn.Module):
+ def __init__(
+ self,
+ vocabulary_size,
+ dim_model,
+ dim_keys,
+ dim_hidden,
+ nb_heads,
+ nb_blocks,
+ nb_lines=None,
+ caterpillar_height=None,
+ dim_rec_v=-1,
+ causal=False,
+ dropout=0.0,
+ len_max=1e5,
+ attention_layer="kvrec",
+ ):
+ super().__init__()
+
+ assert attention_layer in {"mha", "dumbrec", "kvrec", "caterpillar"}
+
+ if attention_layer == "caterpillar":
+ assert nb_lines % caterpillar_height == 0
+ self.caterpillar_length = nb_lines // caterpillar_height
+ self.caterpillar_height = caterpillar_height
+ else:
+ self.caterpillar_length = -1
+ self.caterpillar_height = -1
+
+ assert dim_model % nb_heads == 0
+
+ self.embedding = nn.Sequential(
+ CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
+ AddPositionalEncoding(len_max),
+ )
+
+ trunk_blocks = []
+
+ def attlayer():
+ if attention_layer == "mha":
+ return QKVAttention(
+ dim_in=dim_model,
+ dim_qk=dim_keys,
+ dim_v=dim_model // nb_heads,
+ nb_heads=nb_heads,
+ causal=causal,
+ attention_dropout=dropout,
+ )
+ elif attention_layer == "dumbrec":
+ return DumbRec(
+ dim_in=dim_model,
+ dim_qk=dim_keys,
+ dim_v=dim_rec_v,
+ nb_heads=nb_heads,
+ nb_lines=nb_lines,
+ attention_dropout=dropout,
+ )
+ elif attention_layer == "kvrec":
+ return KVRec(
+ dim_in=dim_model,
+ dim_qk=dim_keys,
+ dim_v=dim_rec_v,
+ nb_heads=nb_heads,
+ nb_lines=nb_lines,
+ attention_dropout=dropout,
+ )
+ elif attention_layer == "caterpillar":
+ return Caterpillar(
+ dim_in=dim_model,
+ dim_qk=dim_keys,
+ dim_v=dim_rec_v,
+ nb_heads=nb_heads,
+ caterpillar_length=self.caterpillar_length,
+ caterpillar_height=self.caterpillar_height,
+ attention_dropout=dropout,
+ )
+ else:
+ raise ValueError(f"Unknown attention type {attention_layer}.")
+
+ for b in range(nb_blocks):
+ trunk_blocks += [
+ WithResidual(
+ CacheWrapper(nn.LayerNorm((dim_model,))),
+ attlayer(),
+ ),
+ WithResidual(
+ CacheWrapper(
+ nn.LayerNorm((dim_model,)),
+ nn.Linear(in_features=dim_model, out_features=dim_hidden),
+ nn.ReLU(),
+ nn.Linear(in_features=dim_hidden, out_features=dim_model),
+ nn.Dropout(dropout),
+ ),
+ ),
+ ]
+
+ self.trunk = nn.Sequential(*trunk_blocks)
+
+ self.readout = CacheWrapper(
+ nn.Linear(in_features=dim_model, out_features=vocabulary_size)
+ )
+
+ with torch.no_grad():
+ for m in self.modules():
+ if isinstance(m, nn.Embedding):
+ m.weight.normal_(mean=0, std=2e-2)
+ elif isinstance(m, nn.LayerNorm):
+ m.bias.zero_()
+ m.weight.fill_(1.0)
+
+ self.reset_inner_loss()
+
+ def forward(self, bs):
+ bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
+
+ # The F.pad above shifts the input one position to the right, so
+ # that the prediction at t is computed from the tokens up to t-1
+
+ # To make the code simpler in the Caterpillar layer, we pad here.
+ # It is unclear how much this costs computationally, since it
+ # increases the sequence length for the other layers
+
+ if self.caterpillar_length > 0:
+ original_nb = bs.nb
+ if bs.nb % self.caterpillar_length > 0:
+ bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length
+
+ bs = BracketedSequence(
+ F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
+ bs.first + self.caterpillar_length,
+ bs.nb,
+ bs.init_cache,
+ )
+
+ bs = self.embedding(bs)
+ bs = self.trunk(bs)
+ bs = self.readout(bs)
+
+ if self.caterpillar_length > 0:
+ bs = BracketedSequence(
+ F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
+ bs.first - self.caterpillar_length,
+ original_nb,
+ bs.init_cache,
+ )
+
+ return bs
+
+ # ar_mask is a tensor with 0s and 1s, of same shape as input, with
+ # 1s where tokens should be generated. The others are kept
+ # unchanged.
+
+ def masked_inplace_autoregression(
+ self,
+ input_src,
+ ar_mask_src,
+ forbidden_tokens=None,
+ deterministic_synthesis=False,
+ ):
+ input = input_src.to(self.readout.f.weight.device)
+ ar_mask = ar_mask_src.to(self.readout.f.weight.device)
+ to_generate = (ar_mask.sum(0) > 0).nonzero()
+ if to_generate.min() > 0:
+ self(
+ BracketedSequence(input, 0, to_generate.min(), True)
+ ) # Needed to initialize the model's cache
+ for s in range(to_generate.min(), to_generate.max() + 1):
+ output = self(BracketedSequence(input, s, 1, s == 0)).x
+ logits = output[:, s]
+ if forbidden_tokens is not None:
+ logits = logits.masked_fill(forbidden_tokens, float("-inf"))
+ if deterministic_synthesis:
+ t_next = logits.argmax(1)
+ else:
+ dist = torch.distributions.categorical.Categorical(logits=logits)
+ t_next = dist.sample()
+ input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
+
+ input_src.copy_(input)
+
+ def reset_inner_loss(self):
+ for m in self.modules():
+ if m is not self and hasattr(m, "reset_inner_loss"):
+ m.reset_inner_loss()
+
+ def get_inner_loss(self):
+ l = torch.tensor([0.0], device=self.readout.f.weight.device)
+ for m in self.modules():
+ if m is not self and hasattr(m, "get_inner_loss"):
+ l += m.get_inner_loss()
+ return l
+
+ def record_attention(self, v=True):
+ for m in self.modules():
+ if isinstance(m, QKVAttention):
+ m.record_attention = v
+
+ def retrieve_attention(self):
+ a = []
+ for m in self.modules():
+ if isinstance(m, QKVAttention):
+ a.append(m.a)
+ return a
+
+
+######################################################################
+
+if __name__ == "__main__":
+ print("Basic check.")
+
+ m = Caterpillar(
+ dim_in=4,
+ dim_qk=3,
+ dim_v=7,
+ nb_heads=1,
+ caterpillar_length=7,
+ caterpillar_height=3,
+ attention_dropout=0.0,
+ )
+
+ m.reset_inner_loss()
+ x = torch.randn(1, 21 + 2 * 7, 4)
+ y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
+ y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
+ y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
+ y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
+ print((y1 - y2).abs().max())
+ print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())
+ exit(0) # stop here; the code below is a speed benchmark
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ vocabulary_size = 128
+ x = torch.randint(vocabulary_size, (6, 1024))
+
+ model = MyGPT(
+ vocabulary_size=vocabulary_size,
+ dim_model=512,
+ dim_keys=64,
+ dim_hidden=2048,
+ nb_heads=8,
+ nb_lines=128,
+ nb_blocks=12,
+ dropout=0.1,
+ causal=True,
+ )
+
+ x = x.to(device)
+ model.to(device)
+
+ import time, sys
+
+ # import torchvision.models as models
+ # from torch.profiler import profile, record_function, ProfilerActivity
+
+ # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
+ # with record_function("model_inference"):
+
+ model.eval()
+ for i in range(3):
+ start_time = time.perf_counter()
+ for k in range(10):
+ model(BracketedSequence(x))
+ duration = time.perf_counter() - start_time
+ print(duration)
+ sys.stdout.flush()
+
+ # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
+ # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+
+ # print("##############################################################")
+ # y2 = torch.randn_like(y1)
+ # for s in range(x.size(1)):
+ # z = model(BracketedSequence(x, s, 1))
+ # y2[:, s : s + 1] = z.slice()
+
+ # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
+
+######################################################################
--- /dev/null
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import math
+import torch, torchvision
+import torch.nn.functional as F
+
+color_name2rgb = {
+ "white": [255, 255, 255],
+ "red": [255, 0, 0],
+ "green": [0, 128, 0],
+ "blue": [0, 0, 255],
+ "yellow": [255, 255, 0],
+ "black": [0, 0, 0],
+ "maroon": [128, 0, 0],
+ "dark_red": [139, 0, 0],
+ "brown": [165, 42, 42],
+ "firebrick": [178, 34, 34],
+ "crimson": [220, 20, 60],
+ "tomato": [255, 99, 71],
+ "coral": [255, 127, 80],
+ "indian_red": [205, 92, 92],
+ "light_coral": [240, 128, 128],
+ "dark_salmon": [233, 150, 122],
+ "salmon": [250, 128, 114],
+ "light_salmon": [255, 160, 122],
+ "orange_red": [255, 69, 0],
+ "dark_orange": [255, 140, 0],
+ "orange": [255, 165, 0],
+ "gold": [255, 215, 0],
+ "dark_golden_rod": [184, 134, 11],
+ "golden_rod": [218, 165, 32],
+ "pale_golden_rod": [238, 232, 170],
+ "dark_khaki": [189, 183, 107],
+ "khaki": [240, 230, 140],
+ "olive": [128, 128, 0],
+ "yellow_green": [154, 205, 50],
+ "dark_olive_green": [85, 107, 47],
+ "olive_drab": [107, 142, 35],
+ "lawn_green": [124, 252, 0],
+ "chartreuse": [127, 255, 0],
+ "green_yellow": [173, 255, 47],
+ "dark_green": [0, 100, 0],
+ "forest_green": [34, 139, 34],
+ "lime": [0, 255, 0],
+ "lime_green": [50, 205, 50],
+ "light_green": [144, 238, 144],
+ "pale_green": [152, 251, 152],
+ "dark_sea_green": [143, 188, 143],
+ "medium_spring_green": [0, 250, 154],
+ "spring_green": [0, 255, 127],
+ "sea_green": [46, 139, 87],
+ "medium_aqua_marine": [102, 205, 170],
+ "medium_sea_green": [60, 179, 113],
+ "light_sea_green": [32, 178, 170],
+ "dark_slate_gray": [47, 79, 79],
+ "teal": [0, 128, 128],
+ "dark_cyan": [0, 139, 139],
+ "aqua": [0, 255, 255],
+ "cyan": [0, 255, 255],
+ "light_cyan": [224, 255, 255],
+ "dark_turquoise": [0, 206, 209],
+ "turquoise": [64, 224, 208],
+ "medium_turquoise": [72, 209, 204],
+ "pale_turquoise": [175, 238, 238],
+ "aqua_marine": [127, 255, 212],
+ "powder_blue": [176, 224, 230],
+ "cadet_blue": [95, 158, 160],
+ "steel_blue": [70, 130, 180],
+ "corn_flower_blue": [100, 149, 237],
+ "deep_sky_blue": [0, 191, 255],
+ "dodger_blue": [30, 144, 255],
+ "light_blue": [173, 216, 230],
+ "sky_blue": [135, 206, 235],
+ "light_sky_blue": [135, 206, 250],
+ "midnight_blue": [25, 25, 112],
+ "navy": [0, 0, 128],
+ "dark_blue": [0, 0, 139],
+ "medium_blue": [0, 0, 205],
+ "royal_blue": [65, 105, 225],
+ "blue_violet": [138, 43, 226],
+ "indigo": [75, 0, 130],
+ "dark_slate_blue": [72, 61, 139],
+ "slate_blue": [106, 90, 205],
+ "medium_slate_blue": [123, 104, 238],
+ "medium_purple": [147, 112, 219],
+ "dark_magenta": [139, 0, 139],
+ "dark_violet": [148, 0, 211],
+ "dark_orchid": [153, 50, 204],
+ "medium_orchid": [186, 85, 211],
+ "purple": [128, 0, 128],
+ "thistle": [216, 191, 216],
+ "plum": [221, 160, 221],
+ "violet": [238, 130, 238],
+ "magenta": [255, 0, 255],
+ "orchid": [218, 112, 214],
+ "medium_violet_red": [199, 21, 133],
+ "pale_violet_red": [219, 112, 147],
+ "deep_pink": [255, 20, 147],
+ "hot_pink": [255, 105, 180],
+ "light_pink": [255, 182, 193],
+ "pink": [255, 192, 203],
+ "antique_white": [250, 235, 215],
+ "beige": [245, 245, 220],
+ "bisque": [255, 228, 196],
+ "blanched_almond": [255, 235, 205],
+ "wheat": [245, 222, 179],
+ "corn_silk": [255, 248, 220],
+ "lemon_chiffon": [255, 250, 205],
+ "light_golden_rod_yellow": [250, 250, 210],
+ "light_yellow": [255, 255, 224],
+ "saddle_brown": [139, 69, 19],
+ "sienna": [160, 82, 45],
+ "chocolate": [210, 105, 30],
+ "peru": [205, 133, 63],
+ "sandy_brown": [244, 164, 96],
+ "burly_wood": [222, 184, 135],
+ "tan": [210, 180, 140],
+ "rosy_brown": [188, 143, 143],
+ "moccasin": [255, 228, 181],
+ "navajo_white": [255, 222, 173],
+ "peach_puff": [255, 218, 185],
+ "misty_rose": [255, 228, 225],
+ "lavender_blush": [255, 240, 245],
+ "linen": [250, 240, 230],
+ "old_lace": [253, 245, 230],
+ "papaya_whip": [255, 239, 213],
+ "sea_shell": [255, 245, 238],
+ "mint_cream": [245, 255, 250],
+ "slate_gray": [112, 128, 144],
+ "light_slate_gray": [119, 136, 153],
+ "light_steel_blue": [176, 196, 222],
+ "lavender": [230, 230, 250],
+ "floral_white": [255, 250, 240],
+ "alice_blue": [240, 248, 255],
+ "ghost_white": [248, 248, 255],
+ "honeydew": [240, 255, 240],
+ "ivory": [255, 255, 240],
+ "azure": [240, 255, 255],
+ "snow": [255, 250, 250],
+ "silver": [192, 192, 192],
+ "gainsboro": [220, 220, 220],
+ "white_smoke": [245, 245, 245],
+}
+
+color_name2id = dict([(n, k) for k, n in enumerate(color_name2rgb.keys())])
+color_id2name = dict([(k, n) for k, n in enumerate(color_name2rgb.keys())])
+
+######################################################################
+
+
+def all_properties(height, width, nb_squares, square_i, square_j, square_c):
+ s = []
+
+ for r, c_r in [(k, color_id2name[square_c[k].item()]) for k in range(nb_squares)]:
+ s += [f"there is {c_r}"]
+
+ if square_i[r] >= height - height // 3:
+ s += [f"{c_r} bottom"]
+ if square_i[r] < height // 3:
+ s += [f"{c_r} top"]
+ if square_j[r] >= width - width // 3:
+ s += [f"{c_r} right"]
+ if square_j[r] < width // 3:
+ s += [f"{c_r} left"]
+
+ for t, c_t in [
+ (k, color_id2name[square_c[k].item()]) for k in range(nb_squares)
+ ]:
+ if square_i[r] > square_i[t]:
+ s += [f"{c_r} below {c_t}"]
+ if square_i[r] < square_i[t]:
+ s += [f"{c_r} above {c_t}"]
+ if square_j[r] > square_j[t]:
+ s += [f"{c_r} right of {c_t}"]
+ if square_j[r] < square_j[t]:
+ s += [f"{c_r} left of {c_t}"]
+
+ return s
+
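+
+# A tiny worked example (hypothetical values): two squares on a 6x8
+# grid, one in the top-left corner and one in the bottom-right one.
+
+
+def _check_all_properties():
+    s = all_properties(
+        height=6,
+        width=8,
+        nb_squares=2,
+        square_i=torch.tensor([0, 5]),
+        square_j=torch.tensor([0, 7]),
+        square_c=torch.tensor([color_name2id["red"], color_name2id["blue"]]),
+    )
+    assert "red top" in s and "red left" in s
+    assert "blue bottom" in s and "blue right" in s
+    assert "red above blue" in s and "blue below red" in s
+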
+
+######################################################################
+
+# Generates sequences
+
+
+def generate(
+ nb,
+ height,
+ width,
+ max_nb_squares=5,
+ max_nb_properties=10,
+ nb_colors=5,
+ pruner=None,
+):
+ assert nb_colors >= max_nb_squares and nb_colors <= len(color_name2rgb) - 1
+
+ descr = []
+
+ for n in range(nb):
+ # We want the number of squares distributed so that the number of
+ # possible configurations (combinations of 1 to max_nb_squares
+ # pixels with nb_colors colors) is roughly uniform, hence
+ # p(nb_squares) grows like nb_colors^nb_squares
+ logits = math.log(nb_colors) * torch.arange(1, max_nb_squares + 1).float()
+ dist = torch.distributions.categorical.Categorical(logits=logits)
+ nb_squares = dist.sample((1,)) + 1
+ # nb_squares = torch.randint(max_nb_squares, (1,)) + 1
+ square_position = torch.randperm(height * width)[:nb_squares]
+
+ # color 0 is white and reserved for the background
+ square_c = torch.randperm(nb_colors)[:nb_squares] + 1
+ square_i = square_position.div(width, rounding_mode="floor")
+ square_j = square_position % width
+
+ img = torch.zeros(height * width, dtype=torch.int64)
+ for k in range(nb_squares):
+ img[square_position[k]] = square_c[k]
+
+ # generates all the true properties
+
+ s = all_properties(height, width, nb_squares, square_i, square_j, square_c)
+
+ if pruner is not None:
+ s = list(filter(pruner, s))
+
+ # picks at most max_nb_properties at random
+
+ nb_properties = torch.randint(max_nb_properties, (1,)) + 1
+ s = (
+ " <sep> ".join([s[k] for k in torch.randperm(len(s))[:nb_properties]])
+ + " <img> "
+ + " ".join([f"{color_id2name[n.item()]}" for n in img])
+ )
+
+ descr += [s]
+
+ return descr
+
+
+######################################################################
+
+# Extracts the images after <img> in descr as a Nx3xHxW tensor
+
+
+def descr2img(descr, height, width):
+ result = []
+
+ def token2color(t):
+ try:
+ return color_name2rgb[t]
+ except KeyError:
+ return [128, 128, 128]
+
+ for d in descr:
+ d = d.split("<img>")[1]
+ d = d.strip().split(" ")[: height * width]
+ d = d + ["<unk>"] * (height * width - len(d))
+ d = [token2color(t) for t in d]
+ img = torch.tensor(d).permute(1, 0).reshape(1, 3, height, width)
+ result.append(img)
+
+ return torch.cat(result, 0)
+
+
+######################################################################
+
+# Returns all the properties of the image after <img> in descr
+
+
+def descr2properties(descr, height, width):
+ if isinstance(descr, list):
+ return [descr2properties(d, height, width) for d in descr]
+
+ d = descr.split("<img>")
+ img_tokens = d[-1] if len(d) > 1 else ""
+ img_tokens = img_tokens.strip().split(" ")[: height * width]
+ if len(img_tokens) != height * width:
+ return []
+
+ seen = {}
+ for k, x in enumerate(img_tokens):
+ if x != color_id2name[0]:
+ if x in color_name2rgb:
+ if x in seen:
+ return []
+ else:
+ return []
+ seen[x] = (color_name2id[x], k // width, k % width)
+
+ square_infos = tuple(zip(*seen.values()))
+
+ if square_infos:
+ square_c = torch.tensor(square_infos[0])
+ square_i = torch.tensor(square_infos[1])
+ square_j = torch.tensor(square_infos[2])
+ else:
+ square_c = torch.tensor([])
+ square_i = torch.tensor([])
+ square_j = torch.tensor([])
+
+ s = all_properties(height, width, len(seen), square_i, square_j, square_c)
+
+ return s
+
+
+######################################################################
+
+# Returns a triplet composed of (1) the total number of properties
+# before <img> in descr, (2) the total number of properties the image
+# after <img> verifies, and (3) the number of properties in (1) not in
+# (2)
+
+
+def nb_properties(descr, height, width, pruner=None):
+ if isinstance(descr, list):
+ return [nb_properties(d, height, width, pruner) for d in descr]
+
+ # str.split always returns at least one piece, so d[0] is the part
+ # before <img> (or the whole string if there is no <img>)
+ d = descr.split("<img>", 1)
+ d = d[0].strip().split("<sep>")
+ d = [x.strip() for x in d]
+
+ actual_properties = set(descr2properties(descr, height, width))
+
+ if pruner is None:
+ requested_properties = set(d)
+ else:
+ requested_properties = set(filter(pruner, d))
+
+ missing_properties = requested_properties - actual_properties
+
+ return (len(requested_properties), len(actual_properties), len(missing_properties))
+
+
+######################################################################
+
+if __name__ == "__main__":
+ for n in range(16):
+ descr = generate(nb=1, height=12, width=16)
+
+ print(nb_properties(descr, height=12, width=16))
+
+ with open(f"picoclvr_example_{n:02d}.txt", "w") as f:
+ for d in descr:
+ f.write(f"{d}\n\n")
+
+ img = descr2img(descr, height=12, width=16)
+ if img.size(0) == 1:
+ img = F.pad(img, (1, 1, 1, 1), value=64)
+
+ torchvision.utils.save_image(
+ img / 255.0,
+ f"picoclvr_example_{n:02d}.png",
+ padding=1,
+ nrow=4,
+ pad_value=0.8,
+ )
+
+ import time
+
+ start_time = time.perf_counter()
+ descr = generate(nb=1000, height=12, width=16)
+ end_time = time.perf_counter()
+ print(f"{len(descr) / (end_time - start_time):.02f} samples per second")
+
+######################################################################
--- /dev/null
+#!/usr/bin/env python
+
+import math
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+######################################################################
+
+
+class Problem:
+ def generate_sequences(self, nb):
+ pass
+
+ def seq2str(self, seq):
+ return "[NOT IMPLEMENTED]"
+
+ def compute_nb_correct(self, input, ar_mask, result):
+ nb_total = ar_mask.sum().item()
+ nb_correct = ((result == input).long() * ar_mask).sum().item()
+ return nb_total, nb_correct
+
+
+####################
+
+
+class ProblemDegradation(Problem):
+ def __init__(self, nb_state_tokens=5, nb_time_steps=12, value_max=25, hard=False):
+ assert value_max // nb_state_tokens >= 2
+ self.nb_state_tokens = nb_state_tokens
+ self.nb_time_steps = nb_time_steps
+ self.value_max = value_max
+ self.hard = hard
+
+ def generate_sequences(self, nb):
+ x = (
+ torch.rand(nb, self.nb_state_tokens).sort(dim=-1).indices == 0
+ ).long() * self.value_max
+ seq = [x]
+
+ for t in range(self.nb_time_steps - 1):
+ v = (torch.rand(x.size()).sort(dim=-1).indices + 1) * (x >= 2).long()
+ u = (v.max(dim=-1, keepdim=True).values == v).long()
+ n = (
+ (u * x)
+ .minimum(2 + torch.randint(self.value_max // 4 - 2, x.size()))
+ .sum(dim=-1, keepdim=True)
+ )
+ m = 1 + ((n - 1) * torch.rand(n.size())).long()
+ x = (
+ x
+ + m * u.roll(shifts=-1, dims=-1)
+ - n * u
+ + (n - m) * u.roll(shifts=1, dims=-1)
+ )
+ seq.append(x)
+
+ if self.hard:
+ seq.reverse()
+
+ seq = torch.cat(seq, dim=1)
+ return seq, seq.new_full(seq.size(), 1, dtype=torch.int64)
+
+ def compute_nb_correct(self, input, ar_mask, result):
+ nb_total = result.size(0)
+ nb_correct = 0
+ e = result.new_zeros(self.nb_state_tokens)
+
+ for seq in result:
+ states = list(seq.split(self.nb_state_tokens))
+ if self.hard:
+ states.reverse()
+
+ d = states[0]
+ j = d.sort(descending=True).indices[0]
+ e.zero_()
+ e[j] = self.value_max
+ if (d - e).abs().sum() == 0:
+ nb_errors = 0
+ for k in range(len(states) - 1):
+ d = states[k + 1] - states[k]
+ j = d.sort(descending=False).indices[0]
+ if (
+ d[j] == 0
+ or d[j] > self.value_max // 4
+ or d[(j + 1) % e.size(0)] <= 0
+ or d[(j + 1) % e.size(0)] >= -d[j]
+ ):
+ nb_errors += 1
+ else:
+ e.zero_()
+ e[j] = d[j]
+ e[(j + 1) % e.size(0)] = d[(j + 1) % e.size(0)]
+ e[(j - 1) % e.size(0)] = -d[(j + 1) % e.size(0)] - d[j]
+ if (d - e).abs().sum() > 0:
+ nb_errors += 1
+ if nb_errors == 0:
+ nb_correct += 1
+
+ return nb_total, nb_correct
+
+ def seq2str(self, seq):
+ return " | ".join(
+ [" ".join([f"{x:02d}" for x in s]) for s in seq.split(self.nb_state_tokens)]
+ )
+
+
+####################
+
+
+class ProblemMemory(Problem):
+ def __init__(self, len_total=32):
+ self.len_total = len_total
+ self.max_len_pattern = 5
+ self.nb_noise_tokens = 10
+ self.start_pattern_token = 0
+ self.end_pattern_token = 1
+ self.start_result_token = 2
+ self.end_result_token = 3
+ self.token_string = "[]<>" + "".join(
+ [chr(ord("a") + k) for k in range(self.nb_noise_tokens)]
+ )
+
+ def generate_sequences(self, nb):
+ sequences = (
+ torch.randint(self.nb_noise_tokens, (nb, self.len_total))
+ + self.end_result_token
+ + 1
+ )
+ len_patterns = torch.randint(self.max_len_pattern, (nb,)) + 1
+ pattern_positions = torch.randint(
+ self.len_total - (5 + 2 * self.max_len_pattern), (nb,)
+ )
+ k = self.len_total - (3 + self.max_len_pattern)
+ for i in range(nb):
+ l = len_patterns[i]
+ j = pattern_positions[i]
+ sequences[i, j] = self.start_pattern_token
+ sequences[i, j + l + 2] = self.end_pattern_token
+ sequences[i, k] = self.start_result_token
+ sequences[i, k + l + 2] = self.end_result_token
+ sequences[i, k + 1 : k + 2 + l] = sequences[i, j + 1 : j + 2 + l]
+
+ j = torch.arange(self.len_total)[None, :]
+ ar_mask = (j > k).long() * (j <= k + 1 + len_patterns[:, None]).long()
+
+ return sequences, ar_mask
+
+ def seq2str(self, seq):
+ return "".join(self.token_string[x.item()] for x in seq)
+
+
+class ProblemTwoTargets(Problem):
+ def __init__(self, len_total=10, len_targets=3):
+ assert len_targets >= 3
+ assert len_total >= 3 * len_targets - 1
+ self.len_total = len_total
+ self.len_targets = len_targets
+
+ def generate_sequences(self, nb):
+ k = torch.arange(self.len_total)[None, :]
+ s = torch.randint(10, (nb, self.len_total))
+ l = torch.rand(nb, self.len_total)
+ l = l * (k <= self.len_total - self.len_targets).long()
+ k1 = l.argmax(dim=1, keepdim=True)
+ m = (k != k1).long() * (k != k1 + self.len_targets - 1).long()
+ s = s * m + 10 * (1 - m)
+ l = l * (
+ 1
+ - (k + self.len_targets - 1 >= k1).long()
+ * (k < k1 + self.len_targets).long()
+ )
+ k2 = l.argmax(dim=1, keepdim=True)
+ m = (k != k2).long() * (k != k2 + self.len_targets - 1).long()
+ s = s * m + 11 * (1 - m)
+ a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :])
+ a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :])
+ sequences = torch.cat(
+ (
+ s,
+ torch.full((nb, 1), 12),
+ a1,
+ torch.full((nb, 1), 12),
+ a2,
+ torch.full((nb, 1), 12),
+ ),
+ 1,
+ )
+ ar_mask = (sequences == 12).long()
+ ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
+ return sequences, ar_mask
+
+ def seq2str(self, seq):
+ return "".join("0123456789-+|"[x.item()] for x in seq)
+
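+
+# A small worked example (hypothetical helper): the cumsum trick used
+# by the generate_sequences methods turns a marker indicator into a
+# mask of all the positions strictly after the first marker, which are
+# the positions the model has to generate.
+
+
+def _check_ar_mask_trick():
+    sequences = torch.tensor([[3, 1, 12, 5, 12, 7]])
+    m = (sequences == 12).long()
+    ar_mask = (m.cumsum(1) - m).clamp(max=1)
+    assert ar_mask.tolist() == [[0, 0, 0, 1, 1, 1]]
+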
+
+####################
+
+
+class ProblemByHeart(Problem):
+ def __init__(self, nb_sentences=100, len_prompt=8, len_result=8):
+ self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
+ self.seq[:, len_prompt] = 10
+
+ def generate_sequences(self, nb):
+ sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
+ ar_mask = (sequences == 10).long()
+ ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
+ return sequences, ar_mask
+
+ def seq2str(self, seq):
+ return "".join("0123456789|"[x.item()] for x in seq)
+
+
+####################
+
+
+class ProblemLearnOperator(Problem):
+ def __init__(self, nb_operators=100, len_source=6, len_result=9):
+ self.len_source = len_source
+ self.len_result = len_result
+ self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
+ self.operators = F.one_hot(
+ torch.rand(nb_operators, len_result, len_source).argmax(-1),
+ num_classes=len_source,
+ )
+
+ def generate_sequences(self, nb):
+ nb_operators = torch.randint(self.operators.size(0), (nb,))
+ operators = self.operators[nb_operators]
+ nb_operators = (
+ nb_operators[:, None]
+ // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
+ ) % 10
+ marker1 = torch.full((nb, 1), 10)
+ source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
+ marker2 = torch.full((nb, 1), 11)
+ result = operators.bmm(source[:, :, None]).squeeze(-1)
+ sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
+ ar_mask = (sequences == 11).long()
+ ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
+ return sequences, ar_mask
+
+ def seq2str(self, seq):
+ return "".join("0123456789|>"[x.item()] for x in seq)
+
+
+####################
+
+
+class ProblemGuessOperator(Problem):
+ def __init__(self, len_source=5, len_result=8):
+ self.len_source = len_source
+ self.len_result = len_result
+
+ def generate_sequences(self, nb):
+ operators = F.one_hot(
+ torch.rand(nb, self.len_result, self.len_source).argmax(-1),
+ num_classes=self.len_source,
+ )
+ source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
+ marker1 = torch.full((nb, 1), 10)
+ result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
+ marker2 = torch.full((nb, 1), 11)
+ source2 = torch.randint(10, (nb, self.len_source))
+ marker3 = torch.full((nb, 1), 12)
+ result2 = operators.bmm(source2[:, :, None]).squeeze(-1)
+
+ sequences = torch.cat(
+ (source1, marker1, result1, marker2, source2, marker3, result2), 1
+ )
+ ar_mask = (sequences == 12).long()
+ ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
+ return sequences, ar_mask
+
+ def seq2str(self, seq):
+ return "".join("0123456789>|~"[x.item()] for x in seq)
+
+
+####################
+
+
+class ProblemAddition(Problem):
+ def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
+ self.nb_digits = nb_digits
+ self.zero_padded = zero_padded
+ self.inverted_result = inverted_result
+ self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")])
+ self.id2char = dict([(n, c) for c, n in self.char2id.items()])
+
+ def tensorize(self, strings):
+ len_max = max([len(x) for x in strings])
+ return torch.cat(
+ [
+ torch.tensor(
+ [
+ [self.char2id[c] for c in s + "$" * (len_max - len(s))]
+ for s in strings
+ ]
+ )
+ ],
+ 0,
+ )
+
+ def generate_sequences(self, nb):
+ sequences = []
+ for k in range(nb):
+ a, b = torch.randint(10**self.nb_digits, (2,))
+ c = a + b
+ a, b, c = str(a.item()), str(b.item()), str(c.item())
+ if self.zero_padded:
+ a = "0" * (self.nb_digits - len(a)) + a
+ b = "0" * (self.nb_digits - len(b)) + b
+ c = "0" * (self.nb_digits + 1 - len(c)) + c
+ if self.inverted_result:
+ c = c[::-1]
+ sequences.append(f"{a}+{b}={c}$")
+
+ sequences = self.tensorize(sequences)
+ ar_mask = (sequences == self.char2id["="]).long()
+ ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
+ return sequences, ar_mask
+
+ def seq2str(self, seq):
+ return "".join(self.id2char[x.item()] for x in seq)
+
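+
+# A usage sketch (hypothetical helper): the sequences hold "a+b=c$"
+# with one character per token, and ar_mask selects everything that
+# comes after the "=" sign.
+
+
+def _check_problem_addition():
+    p = ProblemAddition(nb_digits=2)
+    s, m = p.generate_sequences(4)
+    for seq, mask in zip(s, m):
+        text = p.seq2str(seq)
+        assert mask.sum().item() == len(text) - text.index("=") - 1
+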
+
+####################
+
+
+class ProblemMixing(Problem):
+ def __init__(
+ self, height=4, width=4, nb_time_steps=9, hard=False, random_start=True
+ ):
+ self.height = height
+ self.width = width
+ self.nb_time_steps = nb_time_steps
+ self.hard = hard
+ self.random_start = random_start
+
+ def start_random(self, nb):
+ y = torch.arange(self.height * self.width).reshape(1, -1).expand(nb, -1)
+
+ if self.random_start:
+ i = (
+ torch.arange(self.height)
+ .reshape(1, -1, 1)
+ .expand(nb, self.height, self.width)
+ )
+ j = (
+ torch.arange(self.width)
+ .reshape(1, 1, -1)
+ .expand(nb, self.height, self.width)
+ )
+
+ ri = torch.randint(self.height, (nb,)).reshape(nb, 1, 1)
+ rj = torch.randint(self.width, (nb,)).reshape(nb, 1, 1)
+
+ m = 1 - torch.logical_or(i == ri, j == rj).long().flatten(1)
+
+ y = y * m + self.height * self.width * (1 - m)
+
+ y = y.reshape(nb, self.height, self.width)
+
+ return y
+
+ def start_error(self, x):
+ if self.random_start:
+ i = (
+ torch.arange(self.height, device=x.device)
+ .reshape(1, -1, 1)
+ .expand_as(x)
+ )
+ j = torch.arange(self.width, device=x.device).reshape(1, 1, -1).expand_as(x)
+
+ ri = (
+ (x == self.height * self.width)
+ .long()
+ .sum(dim=-1)
+ .argmax(-1)
+ .view(-1, 1, 1)
+ )
+ rj = (
+ (x == self.height * self.width)
+ .long()
+ .sum(dim=-2)
+ .argmax(-1)
+ .view(-1, 1, 1)
+ )
+
+ m = 1 - torch.logical_or(i == ri, j == rj).long().flatten(1)
+ else:
+ m = 1
+
+ x = x.flatten(1)
+ u = torch.arange(self.height * self.width, device=x.device).reshape(1, -1)
+
+ d = (x - (m * u + (1 - m) * self.height * self.width)).abs().sum(-1)
+
+ return d
+
+ def moves(self, x):
+ y = (
+ x[:, None, :, :]
+ .expand(-1, self.height * 2 + self.width * 2, -1, -1)
+ .clone()
+ )
+ k = 0
+
+ for i in range(self.height):
+ y[:, k, i, :] = y[:, k, i, :].roll(dims=-1, shifts=-1)
+ k += 1
+ y[:, k, i, :] = y[:, k, i, :].roll(dims=-1, shifts=1)
+ k += 1
+
+ for j in range(self.width):
+ y[:, k, :, j] = y[:, k, :, j].roll(dims=-1, shifts=-1)
+ k += 1
+ y[:, k, :, j] = y[:, k, :, j].roll(dims=-1, shifts=1)
+ k += 1
+
+ return y
+
+ def generate_sequences(self, nb):
+ x = self.start_random(nb)
+
+ seq = [x.flatten(1)]
+
+ for t in range(self.nb_time_steps - 1):
+ y = self.moves(x)
+ x = y[torch.arange(nb), torch.randint(y.size(1), (nb,))]
+ seq.append(x.flatten(1))
+
+ if self.hard:
+ seq.reverse()
+
+ seq = torch.cat(seq, dim=1)
+ return seq, seq.new_full(seq.size(), 1, dtype=torch.int64)
+
+ def compute_nb_correct(self, input, ar_mask, result):
+ a = [
+ x.reshape(result.size(0), self.height, self.width)
+ for x in result.split(self.height * self.width, dim=1)
+ ]
+ if self.hard:
+ a.reverse()
+
+ x = a[0]
+
+ d = self.start_error(x)
+
+ for t in range(self.nb_time_steps - 1):
+ x0, x = a[t], a[t + 1]
+ y = self.moves(x0)
+ d = d + (x[:, None] - y).abs().sum((-1, -2)).min(dim=-1).values
+
+ nb_total, nb_correct = result.size(0), (d == 0).long().sum().item()
+
+ return nb_total, nb_correct
+
+ def seq2str(self, seq):
+ return " | ".join(
+ [
+ " ".join(
+ [
+ "-".join(
+ [
+ f"{x:02d}" if x < self.height * self.width else "**"
+ for x in s
+ ]
+ )
+ for s in r.split(self.width)
+ ]
+ )
+ for r in seq.split(self.height * self.width)
+ ]
+ )
+
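+
+# A tiny worked example (hypothetical helper): on a 2x2 grid, moves
+# returns the grids obtained by rolling each row and each column one
+# step in either direction.
+
+
+def _check_moves():
+    p = ProblemMixing(height=2, width=2, random_start=False)
+    x = torch.tensor([[[0, 1], [2, 3]]])
+    y = p.moves(x)
+    assert y.size() == (1, 8, 2, 2)
+    # the first move rolls the first row one step to the left
+    assert y[0, 0].tolist() == [[1, 0], [2, 3]]
+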
+
+####################
+
+if __name__ == "__main__":
+ p = ProblemMixing(height=3, width=3, random_start=False)
+
+ s, m = p.generate_sequences(10000)
+ for x in s[:5]:
+ print(p.seq2str(x))
+ print(p.compute_nb_correct(None, None, s))
--- /dev/null
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import torch
+
+######################################################################
+
+
+class PScan(torch.autograd.Function):
+ # Given A is NxTx1 and X is NxTxD, expands A and X in place in
+ # O(T) operations, and O(log(T)) steps if enough parallel cores
+ # are available, so that the recurrence
+ #
+ # Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t], with Y[:, -1] = Y_init,
+ #
+ # can be computed as
+ #
+ # Y[:, t] = A[:, t] * Y_init + X[:, t]
+
+ @staticmethod
+ def expand_(A, X):
+ if A.size(1) == 1:
+ return
+ T = 2 * (A.size(1) // 2)
+ Aa = A[:, :T].view(A.size(0), T // 2, 2, -1, 1)
+ Xa = X[:, :T].view(X.size(0), T // 2, 2, -1, X.size(-1))
+ Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
+ Aa[:, :, 1].mul_(Aa[:, :, 0])
+ PScan.expand_(Aa[:, :, 1], Xa[:, :, 1])
+ Xa[:, 1:, 0].add_(Aa[:, 1:, 0].mul(Xa[:, :-1, 1]))
+ Aa[:, 1:, 0].mul_(Aa[:, :-1, 1])
+ if T < A.size(1):
+ X[:, -1].add_(A[:, -1].mul(X[:, -2]))
+ A[:, -1].mul_(A[:, -2])
+
+ @staticmethod
+ def acc_rev_(A, X):
+ if X.size(1) == 1:
+ return
+ T = 2 * (X.size(1) // 2)
+ Aa = A[:, -T:].view(A.size(0), T // 2, 2, -1, 1)
+ Xa = X[:, -T:].view(X.size(0), T // 2, 2, -1, X.size(-1))
+ Xa[:, :, 0].add_(Aa[:, :, 1].mul(Xa[:, :, 1]))
+ B = Aa[:, :, 0].clone()
+ B[:, 1:].mul_(Aa[:, :-1, 1])
+ PScan.acc_rev_(B, Xa[:, :, 0])
+ Xa[:, :-1, 1].add_(Aa[:, 1:, 0].mul(Xa[:, 1:, 0]))
+ if T < A.size(1):
+ X[:, 0].add_(A[:, 1].mul(X[:, 1]))
+
+ # A is NxT, X is NxTxD, Y_init is NxD
+ #
+ # returns Y of same shape as X, with
+ #
+ # Y[:, t] = A[:, 0] * Y_init + X[:, 0] if t == 0
+ # = A[:, t] * Y[:, t-1] + X[:, t] otherwise
+
+ @staticmethod
+ def forward(ctx, A, X, Y_init):
+ ctx.A = A.unsqueeze(-1).clone()
+ ctx.Y_init = Y_init[:, None].clone()
+ ctx.A_star = ctx.A.clone()
+ ctx.X_star = X.clone()
+ PScan.expand_(ctx.A_star, ctx.X_star)
+ return ctx.A_star * ctx.Y_init + ctx.X_star
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ U = grad_output * ctx.A_star
+ A = ctx.A.clone()
+ R = grad_output.clone()
+ PScan.acc_rev_(A, R)
+ Q = ctx.Y_init.expand_as(ctx.X_star).clone()
+ Q[:, 1:].mul_(ctx.A_star[:, :-1]).add_(ctx.X_star[:, :-1])
+ return (Q * R).sum(-1), R, U.sum(dim=1)
+
+
+pscan = PScan.apply
+
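+
+# A direct reference implementation (hypothetical helper) of the same
+# recurrence for the NxT / NxTxD case, handy to check pscan on small
+# inputs.
+
+
+def pscan_naive(A, X, Y_init):
+    Y, result = Y_init, []
+    for t in range(X.size(1)):
+        Y = A[:, t, None] * Y + X[:, t]
+        result.append(Y)
+    return torch.stack(result, dim=1)
+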
+######################################################################
+
+if __name__ == "__main__":
+ import time, sys
+
+ A = torch.rand(17, 12, 3)
+ X = torch.rand(17, 12, 3, 11)
+ Y_init = torch.rand(17, 3, 11)
+ Y = pscan(A, X, Y_init)
+ exit(0) # stop here; the code below checks gradients and speed
+
+ N, T, D = 2, 1047, 3
+
+ A = torch.rand(N, T, dtype=torch.float64).requires_grad_()
+ X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
+ Y_init = torch.randn(N, D, dtype=torch.float64).requires_grad_()
+
+ # Iterative implementation
+
+ y = Y_init
+ s = 0
+
+ for k in range(A.size(1)):
+ y = A[:, k, None] * y + X[:, k]
+ s = s + y
+
+ s = s.sum()
+
+ gA_ref, gX_ref, gY_init_ref = torch.autograd.grad(
+ s, (A, X, Y_init), retain_graph=True
+ )
+
+ # parallel scan
+
+ start_time = time.perf_counter()
+ for _ in range(1000):
+ Y = pscan(A, X, Y_init)
+ duration = time.perf_counter() - start_time
+ print(f"duration {duration}")
+
+ s = Y.sum()
+
+ gA, gX, gY_init = torch.autograd.grad(s, (A, X, Y_init), retain_graph=True)
+
+ # print(gA)
+ # print(gX)
+ # print(gY_init)
+
+ print((gA - gA_ref).norm())
+ print((gX - gX_ref).norm())
+ print((gY_init - gY_init_ref).norm())
+
+ Y1 = pscan(A[:, : T // 2], X[:, : T // 2], Y_init)
+ Y2 = pscan(A[:, T // 2 :], X[:, T // 2 :], Y1[:, -1])
+
+ print((Y - torch.cat([Y1, Y2], dim=1)).norm())
--- /dev/null
+#!/usr/bin/env python
+
+# @XREMOTE_HOST: elk.fleuret.org
+# @XREMOTE_EXEC: python
+# @XREMOTE_PRE: source ${HOME}/misc/venv/pytorch/bin/activate
+# @XREMOTE_PRE: killall -u ${USER} -q -9 python || true
+# @XREMOTE_PRE: ln -sf ${HOME}/data/pytorch ./data
+# @XREMOTE_SEND: *.py *.sh
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import math, sys
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+######################################################################
+
+nb_quantization_levels = 101
+
+
+def quantize(x, xmin, xmax):
+ return (
+ ((x - xmin) / (xmax - xmin) * nb_quantization_levels)
+ .long()
+ .clamp(min=0, max=nb_quantization_levels - 1)
+ )
+
+
+def dequantize(q, xmin, xmax):
+ return q / nb_quantization_levels * (xmax - xmin) + xmin
+
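+
+# A round-trip sketch (hypothetical check): quantize maps [xmin, xmax]
+# to {0, ..., nb_quantization_levels - 1}, clamping values outside the
+# range, and dequantize comes back within one quantization step.
+
+
+def _check_quantization():
+    x = torch.linspace(-1.5, 1.5, 101)
+    q = quantize(x, -1, 1)
+    assert q.min() >= 0 and q.max() <= nb_quantization_levels - 1
+    err = (dequantize(q, -1, 1) - x.clamp(min=-1, max=1)).abs().max()
+    assert err <= 2 / nb_quantization_levels + 1e-6
+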
+
+######################################################################
+
+
+def generate_sets_and_params(
+ batch_nb_mlps,
+ nb_samples,
+ batch_size,
+ nb_epochs,
+ device=torch.device("cpu"),
+ print_log=False,
+ save_as_examples=False,
+):
+ data_input = torch.zeros(batch_nb_mlps, 2 * nb_samples, 2, device=device)
+ data_targets = torch.zeros(
+ batch_nb_mlps, 2 * nb_samples, dtype=torch.int64, device=device
+ )
+
+ nb_rec = 8
+ nb_values = 2 # more increases the min-max gap
+
+ rec_support = torch.empty(batch_nb_mlps, nb_rec, 4, device=device)
+
+ while (data_targets.float().mean(-1) - 0.5).abs().max() > 0.1:
+ i = (data_targets.float().mean(-1) - 0.5).abs() > 0.1
+ nb = i.sum()
+ support = torch.rand(nb, nb_rec, 2, nb_values, device=device) * 2 - 1
+ support = support.sort(-1).values
+ support = support[:, :, :, torch.tensor([0, nb_values - 1])].view(nb, nb_rec, 4)
+
+ x = torch.rand(nb, 2 * nb_samples, 2, device=device) * 2 - 1
+ y = (
+ (
+ (x[:, None, :, 0] >= support[:, :, None, 0]).long()
+ * (x[:, None, :, 0] <= support[:, :, None, 1]).long()
+ * (x[:, None, :, 1] >= support[:, :, None, 2]).long()
+ * (x[:, None, :, 1] <= support[:, :, None, 3]).long()
+ )
+ .max(dim=1)
+ .values
+ )
+
+ data_input[i], data_targets[i], rec_support[i] = x, y, support
+
+ train_input, train_targets = (
+ data_input[:, :nb_samples],
+ data_targets[:, :nb_samples],
+ )
+ test_input, test_targets = data_input[:, nb_samples:], data_targets[:, nb_samples:]
+
+ q_train_input = quantize(train_input, -1, 1)
+ train_input = dequantize(q_train_input, -1, 1)
+
+ q_test_input = quantize(test_input, -1, 1)
+ test_input = dequantize(q_test_input, -1, 1)
+
+ if save_as_examples:
+ a = (
+ 2
+ * torch.arange(nb_quantization_levels).float()
+ / (nb_quantization_levels - 1)
+ - 1
+ )
+ xf = torch.cat(
+ [
+ a[:, None, None].expand(
+ nb_quantization_levels, nb_quantization_levels, 1
+ ),
+ a[None, :, None].expand(
+ nb_quantization_levels, nb_quantization_levels, 1
+ ),
+ ],
+ 2,
+ )
+ xf = xf.reshape(1, -1, 2).expand(min(q_train_input.size(0), 10), -1, -1)
+ print(f"{xf.size()=} {x.size()=}")
+ yf = (
+ (
+ (xf[:, None, :, 0] >= rec_support[: xf.size(0), :, None, 0]).long()
+ * (xf[:, None, :, 0] <= rec_support[: xf.size(0), :, None, 1]).long()
+ * (xf[:, None, :, 1] >= rec_support[: xf.size(0), :, None, 2]).long()
+ * (xf[:, None, :, 1] <= rec_support[: xf.size(0), :, None, 3]).long()
+ )
+ .max(dim=1)
+ .values
+ )
+
+ full_input, full_targets = xf, yf
+
+ q_full_input = quantize(full_input, -1, 1)
+ full_input = dequantize(q_full_input, -1, 1)
+
+ for k in range(q_full_input[:10].size(0)):
+ with open(f"example_full_{k:04d}.dat", "w") as f:
+ for u, c in zip(full_input[k], full_targets[k]):
+ f.write(f"{c} {u[0].item()} {u[1].item()}\n")
+
+ for k in range(q_train_input[:10].size(0)):
+ with open(f"example_train_{k:04d}.dat", "w") as f:
+ for u, c in zip(train_input[k], train_targets[k]):
+ f.write(f"{c} {u[0].item()} {u[1].item()}\n")
+
+ hidden_dim = 32
+ w1 = torch.randn(batch_nb_mlps, hidden_dim, 2, device=device) / math.sqrt(2)
+ b1 = torch.zeros(batch_nb_mlps, hidden_dim, device=device)
+ w2 = torch.randn(batch_nb_mlps, 2, hidden_dim, device=device) / math.sqrt(
+ hidden_dim
+ )
+ b2 = torch.zeros(batch_nb_mlps, 2, device=device)
+
+ w1.requires_grad_()
+ b1.requires_grad_()
+ w2.requires_grad_()
+ b2.requires_grad_()
+ optimizer = torch.optim.Adam([w1, b1, w2, b2], lr=1e-2)
+
+ criterion = nn.CrossEntropyLoss()
+ criterion.to(device)
+
+ for k in range(nb_epochs):
+ acc_train_loss = 0.0
+ nb_train_errors = 0
+
+ for input, targets in zip(
+ train_input.split(batch_size, dim=1), train_targets.split(batch_size, dim=1)
+ ):
+ h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
+ h = F.relu(h)
+ output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
+ loss = F.cross_entropy(
+ output.reshape(-1, output.size(-1)), targets.reshape(-1)
+ )
+ acc_train_loss += loss.item() * input.size(0)
+
+ wta = output.argmax(-1)
+ nb_train_errors += (wta != targets).long().sum(-1)
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ with torch.no_grad():
+ for p in [w1, b1, w2, b2]:
+ m = (
+ torch.rand(p.size(), device=p.device) <= k / (nb_epochs - 1)
+ ).long()
+ pq = quantize(p, -2, 2)
+ p[...] = (1 - m) * p + m * dequantize(pq, -2, 2)
+
+ train_error = nb_train_errors / train_input.size(1)
+ acc_train_loss = acc_train_loss / train_input.size(1)
+
+ # print(f"{k=} {acc_train_loss=} {train_error=}")
+
+ acc_test_loss = 0
+ nb_test_errors = 0
+
+ for input, targets in zip(
+ test_input.split(batch_size, dim=1), test_targets.split(batch_size, dim=1)
+ ):
+ h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
+ h = F.relu(h)
+ output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
+ loss = F.cross_entropy(output.reshape(-1, output.size(-1)), targets.reshape(-1))
+ acc_test_loss += loss.item() * input.size(0)
+
+ wta = output.argmax(-1)
+ nb_test_errors += (wta != targets).long().sum(-1)
+
+ test_error = nb_test_errors / test_input.size(1)
+ q_params = torch.cat(
+ [quantize(p.view(batch_nb_mlps, -1), -2, 2) for p in [w1, b1, w2, b2]], dim=1
+ )
+ q_train_set = torch.cat([q_train_input, train_targets[:, :, None]], -1).reshape(
+ batch_nb_mlps, -1
+ )
+ q_test_set = torch.cat([q_test_input, test_targets[:, :, None]], -1).reshape(
+ batch_nb_mlps, -1
+ )
+
+ return q_train_set, q_test_set, q_params, test_error
+
+
+######################################################################
+
+
+def evaluate_q_params(
+ q_params,
+ q_set,
+ batch_size=25,
+ device=torch.device("cpu"),
+ nb_mlps_per_batch=1024,
+ save_as_examples=False,
+):
+ errors = []
+ nb_mlps = q_params.size(0)
+
+ for n in range(0, nb_mlps, nb_mlps_per_batch):
+ batch_nb_mlps = min(nb_mlps_per_batch, nb_mlps - n)
+ batch_q_params = q_params[n : n + batch_nb_mlps]
+ batch_q_set = q_set[n : n + batch_nb_mlps]
+ hidden_dim = 32
+ w1 = torch.empty(batch_nb_mlps, hidden_dim, 2, device=device)
+ b1 = torch.empty(batch_nb_mlps, hidden_dim, device=device)
+ w2 = torch.empty(batch_nb_mlps, 2, hidden_dim, device=device)
+ b2 = torch.empty(batch_nb_mlps, 2, device=device)
+
+ with torch.no_grad():
+ k = 0
+ for p in [w1, b1, w2, b2]:
+ print(f"{p.size()=}")
+ x = dequantize(
+ batch_q_params[:, k : k + p.numel() // batch_nb_mlps], -2, 2
+ ).view(p.size())
+ p.copy_(x)
+ k += p.numel() // batch_nb_mlps
+
+ batch_q_set = batch_q_set.view(batch_nb_mlps, -1, 3)
+ data_input = dequantize(batch_q_set[:, :, :2], -1, 1).to(device)
+ data_targets = batch_q_set[:, :, 2].to(device)
+
+ print(f"{data_input.size()=} {data_targets.size()=}")
+
+ criterion = nn.CrossEntropyLoss()
+ criterion.to(device)
+
+ acc_loss = 0.0
+ nb_errors = 0
+
+ for input, targets in zip(
+ data_input.split(batch_size, dim=1), data_targets.split(batch_size, dim=1)
+ ):
+ h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
+ h = F.relu(h)
+ output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
+ loss = F.cross_entropy(
+ output.reshape(-1, output.size(-1)), targets.reshape(-1)
+ )
+ acc_loss += loss.item() * input.size(0)
+ wta = output.argmax(-1)
+ nb_errors += (wta != targets).long().sum(-1)
+
+ errors.append(nb_errors / data_input.size(1))
+ acc_loss = acc_loss / data_input.size(1)
+
+ return torch.cat(errors)
+
+
+######################################################################
+
+
+def generate_sequence_and_test_set(
+ nb_mlps,
+ nb_samples,
+ batch_size,
+ nb_epochs,
+ device,
+ nb_mlps_per_batch=1024,
+):
+ seqs, q_test_sets, test_errors = [], [], []
+
+ for n in range(0, nb_mlps, nb_mlps_per_batch):
+ q_train_set, q_test_set, q_params, test_error = generate_sets_and_params(
+ batch_nb_mlps=min(nb_mlps_per_batch, nb_mlps - n),
+ nb_samples=nb_samples,
+ batch_size=batch_size,
+ nb_epochs=nb_epochs,
+ device=device,
+ )
+
+ seqs.append(
+ torch.cat(
+ [
+ q_train_set,
+ q_train_set.new_full(
+ (
+ q_train_set.size(0),
+ 1,
+ ),
+ nb_quantization_levels,
+ ),
+ q_params,
+ ],
+ dim=-1,
+ )
+ )
+
+ q_test_sets.append(q_test_set)
+ test_errors.append(test_error)
+
+ seq = torch.cat(seqs)
+ q_test_set = torch.cat(q_test_sets)
+ test_error = torch.cat(test_errors)
+
+ return seq, q_test_set, test_error
+
+
+######################################################################
+
+if __name__ == "__main__":
+ import time
+
+ batch_nb_mlps, nb_samples = 128, 250
+
+ generate_sets_and_params(
+ batch_nb_mlps=10,
+ nb_samples=nb_samples,
+ batch_size=25,
+ nb_epochs=100,
+ device=torch.device("cpu"),
+ print_log=False,
+ save_as_examples=True,
+ )
+
+ exit(0) # stop here; the code below runs the full generation and evaluation
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ start_time = time.perf_counter()
+
+ data = []
+
+ seq, q_test_set, test_error = generate_sequence_and_test_set(
+ nb_mlps=batch_nb_mlps,
+ nb_samples=nb_samples,
+ device=device,
+ batch_size=25,
+ nb_epochs=250,
+ nb_mlps_per_batch=17,
+ )
+
+ end_time = time.perf_counter()
+ print(f"{seq.size(0) / (end_time - start_time):.02f} samples per second")
+
+ q_train_set = seq[:, : nb_samples * 3]
+ q_params = seq[:, nb_samples * 3 + 1 :]
+ print(f"SANITY #2 {q_train_set.size()=} {q_params.size()=} {seq.size()=}")
+ error_train = evaluate_q_params(q_params, q_train_set, nb_mlps_per_batch=17)
+ print(f"train {error_train*100}%")
+ error_test = evaluate_q_params(q_params, q_test_set, nb_mlps_per_batch=17)
+ print(f"test {error_test*100}%")
--- /dev/null
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import math
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+######################################################################
+
+
+def rpl_exec(program, stack):
+ stack = stack.copy()
+ for op in program:
+ if op == "add":
+ if len(stack) > 1:
+ a, b = stack.pop(), stack.pop()
+ stack.append(a + b)
+ elif op == "min":
+ if len(stack) > 1:
+ a, b = stack.pop(), stack.pop()
+ stack.append(min(a, b))
+ elif op == "max":
+ if len(stack) > 1:
+ a, b = stack.pop(), stack.pop()
+ stack.append(max(a, b))
+ elif op == "swp":
+ if len(stack) > 1:
+ a, b = stack.pop(), stack.pop()
+ stack.append(a)
+ stack.append(b)
+ elif op == "rep":
+ if len(stack) > 1:
+ a, b = stack.pop(), stack.pop()
+ stack += [b] * a
+ elif op == "dup":
+ if len(stack) > 0:
+ a = stack.pop()
+ stack.append(a)
+ stack.append(a)
+ elif op == "del":
+ if len(stack) > 0:
+ a = stack.pop()
+ else:
+ raise ValueError(f"Unknown instruction {op}")
+
+ return stack
+
+
+rpl_ops = ["add", "min", "max", "swp", "rep", "dup", "del"]
+
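+# A few worked examples of the semantics implemented by rpl_exec (the
+# top of the stack is the end of the list):
+#
+#   rpl_exec(["dup", "add"], [3])  ->  [6]
+#   rpl_exec(["swp"], [1, 2])      ->  [2, 1]
+#   rpl_exec(["rep"], [5, 2])      ->  [5, 5]  (pops a=2, b=5, pushes b a times)
+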
+######################################################################
+
+
+def generate(
+ nb_starting_values=3, nb_result_values_max=None, max_input=9, prog_len=6, nb_runs=5
+):
+ prog_len = (1 + torch.randint(2 * prog_len, (1,))).clamp(max=prog_len).item()
+
+ while True:
+ no_empty_stack = True
+ prog = [rpl_ops[k] for k in torch.randint(len(rpl_ops), (prog_len,))]
+
+ result = []
+ for _ in range(nb_runs):
+ stack = [
+ x.item() for x in torch.randint(max_input + 1, (nb_starting_values,))
+ ]
+ result_stack = rpl_exec(prog, stack)
+ if len(result_stack) == 0:
+ no_empty_stack = False
+ result = result + ["<in>"] + stack + ["<out>"] + result_stack
+
+ result = result + ["<prg>"] + prog
+ result = result + ["<end>"]
+
+ if no_empty_stack and (
+ nb_result_values_max is None or len(result_stack) <= nb_result_values_max
+ ):
+ break
+
+ return result
+
+
+def next_marker(seq, tokens, start=0):
+ pos = None
+ for t in tokens:
+ try:
+ i = seq.index(t, start)
+ if pos is None or i < pos:
+ pos = i
+ except ValueError:
+ pass
+ return pos
+
+
+def decompose(seq):
+ io = []
+ k = 0
+ while seq[k] == "<in>":
+ o = next_marker(seq, ["<out>"], start=k + 1)
+ if o is None:
+ raise ValueError("Missing output markers (should be correct in the prompt)")
+ e = next_marker(seq, ["<in>", "<prg>"], start=o)
+ if e is None:
+ raise ValueError(
+ "Missing input/output markers (should be correct in the prompt)"
+ )
+ try:
+ io.append(
+ ([int(x) for x in seq[k + 1 : o]], [int(x) for x in seq[o + 1 : e]])
+ )
+ except ValueError:
+ raise ValueError(
+ "Invalid input/output value (should be correct in the prompt)"
+ )
+
+ k = e
+
+ if seq[k] == "<prg>":
+ e = next_marker(seq, ["<end>"], start=k)
+ if e is None:
+ prog = []
+ else:
+ prog = seq[k + 1 : e]
+ else:
+ raise ValueError("Missing <prg> (it should be in the prompt)")
+
+ return prog, io
+
+
+def stack_distance(target_stack, result_stack):
+ return abs(len(result_stack) - len(target_stack)) + sum(
+ [0 if x == y else 1 for x, y in zip(result_stack, target_stack)]
+ )
+
+
+def compute_nb_errors(seq):
+ prog, io = decompose(seq)
+
+ nb_total, nb_errors = 0, 0
+
+ stacks = []
+
+ if len(set(prog) - set(rpl_ops)) > 0:
+ # Program is not valid, we count 100% error
+ for start_stack, target_stack in io:
+ stacks.append((start_stack, target_stack, ["N/A"], False))
+ nb_total += len(target_stack)
+ nb_errors += len(target_stack)
+
+ else:
+ # Program is valid
+ for start_stack, target_stack in io:
+ result_stack = rpl_exec(prog, start_stack)
+ nb_total += len(target_stack)
+ e = stack_distance(target_stack, result_stack)
+ nb_errors += e
+ stacks.append((start_stack, target_stack, result_stack, e == 0))
+
+ return nb_total, nb_errors, prog, stacks
+
+
+######################################################################
+
+if __name__ == "__main__":
+ seq = generate()
+ print(seq)
+ seq[3] = 7 # corrupt the sequence to exercise the error counting
+ print(seq)
+ print(compute_nb_errors(seq))
--- /dev/null
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import torch, torchvision
+import torch.nn.functional as F
+
+
+def generate_sequences(
+ nb, height, width, nb_colors, length, prompt_length, device=torch.device("cpu")
+):
+ worlds = torch.randint(nb_colors, (nb, height, width), device=device)
+ world_prior_visits = torch.zeros(nb, height, width, device=device)
+
+ # nb x 2
+ snake_position = torch.cat(
+ (
+ torch.randint(height, (nb, 1), device=device),
+ torch.randint(width, (nb, 1), device=device),
+ ),
+ 1,
+ )
+ snake_direction = torch.randint(4, (nb,), device=device)
+ sequences = torch.empty(nb, 2 * length, device=device, dtype=torch.int64)
+ sequences_prior_visits = torch.zeros(
+ nb, 2 * length, device=device, dtype=torch.int64
+ )
+ i = torch.arange(nb, device=device) # [:,None]
+
+ for l in range(length):
+ # nb x 3
+ snake_next_direction = torch.cat(
+ (
+ (snake_direction[:, None] - 1) % 4,
+ snake_direction[:, None],
+ (snake_direction[:, None] + 1) % 4,
+ ),
+ 1,
+ )
+
+ # nb x 3
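+ # The direction d encodes a move of (vh, vw):
+ # d=0 -> (-1, 0), d=1 -> (0, -1), d=2 -> (+1, 0), d=3 -> (0, +1)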
+ vh = (snake_next_direction + 1) % 2 * (snake_next_direction - 1)
+ vw = snake_next_direction % 2 * (snake_next_direction - 2)
+
+ # nb x 3 x 2
+ snake_next_speed = torch.cat((vh[:, :, None], vw[:, :, None]), 2)
+ snake_next_position = snake_position[:, None, :] + snake_next_speed
+
+ # nb x 3
+ val = torch.logical_and(
+ torch.logical_and(
+ snake_next_position[:, :, 0] >= 0, snake_next_position[:, :, 0] < height
+ ),
+ torch.logical_and(
+ snake_next_position[:, :, 1] >= 0, snake_next_position[:, :, 1] < width
+ ),
+ ).float()
+ val = (
+ # The multiplicative factors bias toward moving forward
+ torch.rand_like(val)
+ * val
+ * torch.tensor([[1.0, 2.0, 1.0]], device=device)
+ )
+
+ # nb
+ j = val.argmax(1)
+ snake_direction = snake_next_direction[i, j]
+
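+ # The color tokens are offset by 4 so that they do not collide
+ # with the 4 direction tokens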
+ sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4
+ sequences_prior_visits[:, 2 * l] = world_prior_visits[
+ i, snake_position[:, 0], snake_position[:, 1]
+ ]
+ if l < prompt_length:
+ world_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1
+ sequences[:, 2 * l + 1] = snake_direction
+
+ # nb x 2
+ snake_position = snake_next_position[i, j]
+
+ return sequences, sequences_prior_visits, worlds, world_prior_visits
+
+
+# generate_sequences(nb=1, height=4, width=6, nb_colors=3, length=20, prompt_length=10)
+# exit(0)
+
+
+def solver(input, ar_mask):
+ for n in range(input.size(0)):
+ i, j, memory = 0, 0, {}
+ # print(input[n])
+ # print(ar_mask[n])
+ for l in range(input.size(1) // 2):
+ if ar_mask[n, 2 * l] == 1:
+ if memory.get((i, j)) is None:
+ input[n, 2 * l] = -1
+ else:
+ input[n, 2 * l] = memory[(i, j)]
+ else:
+ # print(f'@3 {memory=}')
+ if memory.get((i, j)) is None:
+ memory[(i, j)] = input[n, 2 * l]
+ else:
+ assert memory[(i, j)] == input[n, 2 * l], f"n={n} l={l}"
+ # print(f'@1 {i=} {j=}')
+ d = input[n, 2 * l + 1].item()
+ i += (d + 1) % 2 * (d - 1)
+ j += d % 2 * (d - 2)
+ # print(f'@2 {i=} {j=}')
+
+
+def seq2str(seq):
+ return "".join(["NESW123456789"[i] for i in seq])
+
+
+######################################################################
+
+if __name__ == "__main__":
+ train_input, train_prior_visits, _, _ = generate_sequences(
+ nb=20,
+ height=9,
+ width=12,
+ nb_colors=5,
+ length=50,
+ prompt_length=100,
+ )
+
+ print([seq2str(s) for s in train_input])
+
+######################################################################
--- /dev/null
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import torch, torchvision
+
+######################################################################
+
+# CODE_OP = op + 2 * stack_index, with op = 0 for push and 1 for pop
+# CODE_VAL = digit + 2 * nb_stacks
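+#
+# E.g. with nb_stacks=2 and nb_digits=1, pushing the value 5 on stack 1
+# is encoded as the two tokens [1 * 2 + 0, 5 + 2 * 2] = [2, 9]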
+
+
+def generate_sequences(
+ nb, nb_steps, nb_stacks, nb_digits, values=None, device=torch.device("cpu")
+):
+ stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64)
+ stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64)
+ k = torch.arange(nb)
+ result = torch.empty(nb, (1 + nb_digits) * nb_steps, dtype=torch.int64)
+ recorded_stack_counts = torch.zeros(
+ nb, (1 + nb_digits) * nb_steps, dtype=torch.int64
+ )
+
+ for t in range(nb_steps):
+ op = torch.randint(2, (nb,))
+ st = torch.randint(nb_stacks, (nb,))
+ op = op * (stack_counts[k, st] > 0)
+ if values is None:
+ val_push = torch.randint(10**nb_digits, (nb,))
+ else:
+ val_push = values[torch.randint(values.size(0), (nb,))]
+ val_pop = stack[
+ k,
+ st,
+ (stack_counts[k, st] - 1).clamp(min=0),
+ ]
+ stack[k, st, stack_counts[k, st]] = val_push
+ recorded_stack_counts[:, (1 + nb_digits) * t] = stack_counts[k, st]
+ stack_counts[k[op == 0], st[op == 0]] += 1
+ stack_counts[k[op == 1], st[op == 1]] -= 1
+ result[:, (1 + nb_digits) * t] = st * 2 + op
+ for d in range(nb_digits):
+ result[:, (1 + nb_digits) * t + 1 + d] = (
+ (op * val_pop + (1 - op) * val_push) // (10**d)
+ ) % 10 + 2 * nb_stacks
+
+ return result.to(device), recorded_stack_counts.to(device)
+
+
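+# Replace with -1 the nb_digits value tokens that follow every pop
+# operation, so that they become the positions to predict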
+def remove_popped_values(seq, nb_stacks, nb_digits):
+ m = torch.logical_and(seq % 2 == 1, seq < 2 * nb_stacks).long()
+ for d in range(nb_digits):
+ k = d + 1
+ seq[:, k:] = -m[:, :-k] + (1 - m[:, :-k]) * seq[:, k:]
+
+
+def seq_to_str(seq, nb_stacks, nb_digits, recorded_stack_counts=None):
+ assert seq.size(0) % (1 + nb_digits) == 0
+ s = ""
+ for t in range(seq.size(0) // (1 + nb_digits)):
+ n_op = seq[(1 + nb_digits) * t]
+ if t > 0:
+ s += " "
+ if recorded_stack_counts is not None:
+ s += f"[{recorded_stack_counts[(1 + nb_digits)*t]}] "
+ s += f"POP" if n_op % 2 == 1 else f"PSH"
+ if nb_stacks > 1:
+ s += f"_{n_op//2}"
+ for d in range(nb_digits):
+ if seq[(1 + nb_digits) * t + 1 + d] == -1:
+ s += " ?"
+ else:
+ s += f" {seq[(1 + nb_digits) * t + 1 + d] - 2 * nb_stacks:1d}"
+ return s
+
+
+######################################################################
+
+if __name__ == "__main__":
+ nb, nb_steps, nb_stacks, nb_digits = 150000, 20, 2, 1
+ seq, recorded_stack_counts = generate_sequences(
+ nb=nb,
+ nb_steps=nb_steps,
+ nb_stacks=nb_stacks,
+ nb_digits=nb_digits,
+ )
+
+ for n in range(min(10, seq.size(0))):
+ print(
+ seq_to_str(
+ seq[n],
+ nb_stacks=nb_stacks,
+ nb_digits=nb_digits,
+ recorded_stack_counts=recorded_stack_counts[n],
+ )
+ )
+ # print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))
+
+ print("-- PREPARED FOR TEST -----------------")
+
+ remove_popped_values(seq, nb_stacks, nb_digits)
+
+ for n in range(min(10, seq.size(0))):
+ print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))
--- /dev/null
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import math, os, tqdm
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+from mygpt import BracketedSequence
+
+# from graph import save_attention_image
+save_attention_image = None
+
+######################################################################
+
+
+def masked_inplace_autoregression(
+ model,
+ batch_size,
+ input,
+ ar_mask,
+ deterministic_synthesis,
+ forbidden_tokens=None,
+ progress_bar_desc="autoregression",
+ device=torch.device("cpu"),
+):
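+ # Generates, in place and batch by batch, the entries of input for
+ # which ar_mask is 1, by sampling the model autoregressively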
+ assert input.size() == ar_mask.size()
+
+ batches = zip(input.split(batch_size), ar_mask.split(batch_size))
+
+ if progress_bar_desc is not None:
+ batches = tqdm.tqdm(
+ batches,
+ dynamic_ncols=True,
+ desc=progress_bar_desc,
+ total=(input.size(0) + batch_size - 1) // batch_size,
+ )
+
+ with torch.autograd.no_grad():
+ t = model.training
+ model.eval()
+
+ for input, ar_mask in batches:
+ model.masked_inplace_autoregression(
+ input, ar_mask, forbidden_tokens, deterministic_synthesis
+ )
+
+ model.train(t)
+
+
+######################################################################
+
+
+class Task:
+ def batches(self, split="train"):
+ pass
+
+ def vocabulary_size(self):
+ pass
+
+ def produce_results(
+ self, n_epoch, model, result_dir, logger, deterministic_synthesis
+ ):
+ pass
+
+
+####################
+
+import problems
+
+
+class SandBox(Task):
+ def __init__(
+ self,
+ problem,
+ nb_train_samples,
+ nb_test_samples,
+ batch_size,
+ logger=None,
+ device=torch.device("cpu"),
+ max_nb_codes=1024,
+ ):
+ super().__init__()
+
+ self.batch_size = batch_size
+ self.device = device
+ self.problem = problem
+
+ self.train_input, self.train_ar_mask = self.problem.generate_sequences(
+ nb_train_samples
+ )
+ self.test_input, self.test_ar_mask = self.problem.generate_sequences(
+ nb_test_samples
+ )
+
+ self.train_input, self.train_ar_mask = self.train_input.to(
+ device
+ ), self.train_ar_mask.to(device)
+ self.test_input, self.test_ar_mask = self.test_input.to(
+ device
+ ), self.test_ar_mask.to(device)
+
+ self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+
+ # A bit of paranoia never hurts
+ assert self.nb_codes <= max_nb_codes
+ assert self.train_input.min() >= 0
+ assert self.test_input.min() >= 0
+ assert tuple(x.item() for x in self.train_ar_mask.unique()) in {
+ (0,),
+ (1,),
+ (0, 1),
+ }
+ assert tuple(x.item() for x in self.test_ar_mask.unique()) in {
+ (0,),
+ (1,),
+ (0, 1),
+ }
+
+ if logger is not None:
+ for s, a in zip(self.train_input[:100], self.train_ar_mask[:100]):
+ logger(f"train_sequences {self.problem.seq2str(s)}")
+ a = "".join(["01"[x.item()] for x in a])
+ logger(f" {a}")
+
+ def batches(self, split="train", nb_to_use=-1, desc=None):
+ assert split in {"train", "test"}
+ input = self.train_input if split == "train" else self.test_input
+ if nb_to_use > 0:
+ input = input[:nb_to_use]
+ if desc is None:
+ desc = f"epoch-{split}"
+ for batch in tqdm.tqdm(
+ input.split(self.batch_size), dynamic_ncols=True, desc=desc
+ ):
+ yield batch
+
+ def vocabulary_size(self):
+ return self.nb_codes
+
+ def produce_results(
+ self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
+ ):
+ def compute_accuracy(input, ar_mask, logger=None):
+ input, ar_mask = input[:nmax], ar_mask[:nmax]
+ result = input.clone() * (1 - ar_mask)
+
+ masked_inplace_autoregression(
+ model,
+ self.batch_size,
+ result,
+ ar_mask,
+ deterministic_synthesis,
+ progress_bar_desc=None,
+ device=self.device,
+ )
+
+ log_ground_truth = ar_mask.min() == 0
+
+ if logger is not None:
+ for sp, st in zip(result[:10], input[:10]):
+ logger(
+ f"test_sequences {n_epoch} prediction {self.problem.seq2str(sp)}"
+ )
+ if log_ground_truth:
+ logger(
+ f" {n_epoch} ground truth {self.problem.seq2str(st)}"
+ )
+
+ nb_total, nb_correct = self.problem.compute_nb_correct(
+ input, ar_mask, result
+ )
+
+ # nb_total = ar_mask.sum().item()
+ # nb_correct = ((result == input).long() * ar_mask).sum().item()
+
+ return nb_total, nb_correct
+
+ train_nb_total, train_nb_correct = compute_accuracy(
+ self.train_input, self.train_ar_mask
+ )
+
+ logger(
+ f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
+ )
+
+ test_nb_total, test_nb_correct = compute_accuracy(
+ self.test_input, self.test_ar_mask, logger
+ )
+
+ logger(
+ f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+ )
+
+ logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
+
+ if save_attention_image is not None:
+ for k in range(10):
+ ns = torch.randint(self.test_input.size(0), (1,)).item()
+ input = self.test_input[ns : ns + 1].clone()
+
+ with torch.autograd.no_grad():
+ t = model.training
+ model.eval()
+ # model.record_attention(True)
+ model(BracketedSequence(input))
+ model.train(t)
+ # ram = model.retrieve_attention()
+ # model.record_attention(False)
+
+ # tokens_output = [c for c in self.problem.seq2str(input[0])]
+ # tokens_input = ["n/a"] + tokens_output[:-1]
+ # for n_head in range(ram[0].size(1)):
+ # filename = os.path.join(
+ # result_dir, f"sandbox_attention_{k}_h{n_head}.pdf"
+ # )
+ # attention_matrices = [m[0, n_head] for m in ram]
+ # save_attention_image(
+ # filename,
+ # tokens_input,
+ # tokens_output,
+ # attention_matrices,
+ # k_top=10,
+ ##min_total_attention=0.9,
+ # token_gap=12,
+ # layer_gap=50,
+ # )
+ # logger(f"wrote {filename}")
+
+
+######################################################################
+
+import picoclvr
+
+
+class PicoCLVR(Task):
+ # Make a tensor from a list of strings
+ def tensorize(self, descr):
+ token_descr = [s.strip().split(" ") for s in descr]
+ l = max([len(s) for s in token_descr])
+ token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
+ id_descr = [[self.token2id[u] for u in s] for s in token_descr]
+ return torch.tensor(id_descr, device=self.device)
+
+ # Make a list of strings from a tensor
+ def detensorize(self, x):
+ return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
+
+ # Trim as many columns of `token` as possible on the left and on
+ # the right of the first tensor. If z is a tuple, all its elements
+ # are trimmed according to the trimming computed for the first one
+ def trim(self, z, token="<nul>"):
+ n = self.token2id[token]
+ if type(z) == tuple:
+ x = z[0]
+ i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
+ a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
+ return tuple([t[:, a:b] for t in z])
+ else:
+ i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
+ a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
+ return z[:, a:b]
+
+ ######################
+
+ def __init__(
+ self,
+ nb_train_samples,
+ nb_test_samples,
+ batch_size,
+ height,
+ width,
+ nb_colors=5,
+ logger=None,
+ device=torch.device("cpu"),
+ pruner_train=None,
+ pruner_eval=None,
+ ):
+ super().__init__()
+
+ def generate_descr(nb, cache_suffix, pruner):
+ return picoclvr.generate(
+ nb,
+ height=self.height,
+ width=self.width,
+ nb_colors=nb_colors,
+ pruner=pruner,
+ )
+
+ self.height = height
+ self.width = width
+ self.batch_size = batch_size
+ self.device = device
+ self.pruner_train = pruner_train
+ self.pruner_eval = pruner_eval
+
+ if logger is not None:
+ logger(
+ f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
+ )
+
+ self.train_descr = generate_descr(
+ nb_train_samples, "train", pruner=self.pruner_train
+ )
+ self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)
+
+ # Build the tokenizer
+ tokens = {"<nul>", "<img>"}
+ for d in [self.train_descr, self.test_descr]:
+ for s in d:
+ for t in s.strip().split(" "):
+ tokens.add(t)
+ # make this set a sorted list to get the same tensors given
+ # the same descr
+ tokens = list(tokens)
+ tokens.sort()
+ self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
+ self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
+ self.t_img, self.t_nul = self.token2id["<img>"], self.token2id["<nul>"]
+
+ # Tokenize the train and test sets
+ self.train_input = self.tensorize(self.train_descr)
+ self.test_input = self.tensorize(self.test_descr)
+
+ def batches(self, split="train"):
+ assert split in {"train", "test"}
+ input = self.train_input if split == "train" else self.test_input
+ for batch in tqdm.tqdm(
+ input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
+ ):
+ yield self.trim(batch)
+
+ def vocabulary_size(self):
+ return len(self.token2id)
+
+ def compute_missing_properties(
+ self, n_epoch, model, logger, deterministic_synthesis, pruner=None
+ ):
+ acc_nb_requested_properties = []
+ acc_nb_missing_properties = []
+ acc_nb_results = 0
+
+ for input in tqdm.tqdm(
+ self.test_input.split(self.batch_size),
+ dynamic_ncols=True,
+ desc=f"test-properties",
+ ):
+ result = input.clone()
+ ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1)
+ result = (1 - ar_mask) * result + ar_mask * self.t_nul
+ masked_inplace_autoregression(
+ model,
+ self.batch_size,
+ result,
+ ar_mask,
+ deterministic_synthesis,
+ progress_bar_desc=None,
+ device=self.device,
+ )
+
+ result_descr = self.detensorize(result)
+ np = picoclvr.nb_properties(
+ result_descr,
+ height=self.height,
+ width=self.width,
+ pruner=pruner,
+ )
+ nb_requested_properties, _, nb_missing_properties = zip(*np)
+ acc_nb_requested_properties += nb_requested_properties
+ acc_nb_missing_properties += nb_missing_properties
+ acc_nb_results += len(result_descr)
+
+ nb_requested_properties = sum(acc_nb_requested_properties)
+ nb_missing_properties = sum(acc_nb_missing_properties)
+
+ prefix = "" if pruner is None else "pruned_"
+ logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
+ logger(
+ f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
+ )
+ logger(
+ f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
+ )
+
+ logger(
+ f"main_test_accuracy {n_epoch} {1-nb_missing_properties/nb_requested_properties}"
+ )
+
+ ######################################################################
+
+ def produce_results(
+ self, n_epoch, model, result_dir, logger, deterministic_synthesis
+ ):
+ self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis)
+
+ if self.pruner_eval is not None:
+ self.compute_missing_properties(
+ n_epoch, model, logger, deterministic_synthesis, pruner=self.pruner_eval
+ )
+
+ nb_tokens_to_generate = self.height * self.width + 3
+ result_descr = []
+ nb_per_primer = 8
+ primer = []
+
+ for primer_descr in [
+ "red above green <sep> green top <sep> blue right of red",
+ "there is red <sep> there is yellow <sep> there is blue",
+ "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
+ "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
+ ]:
+ primer += [primer_descr + " <img>"] * nb_per_primer
+
+ result = self.tensorize(primer)
+ fill = result.new_full(
+ result.size()[:-1] + (self.height * self.width + 1,), self.t_nul
+ )
+ result = torch.cat((result, fill), 1)
+ ar_mask = (result == self.t_nul).long()
+ masked_inplace_autoregression(
+ model,
+ self.batch_size,
+ result,
+ ar_mask,
+ deterministic_synthesis,
+ device=self.device,
+ )
+ result_descr = self.detensorize(result)
+
+ np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
+
+ acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
+ acc_nb_results = len(result_descr)
+
+ nb_requested_properties = sum(acc_nb_requested_properties)
+ nb_missing_properties = sum(acc_nb_missing_properties)
+
+ prefix = "demo_"
+ logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
+ logger(
+ f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
+ )
+ logger(
+ f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
+ )
+
+ img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)
+
+ if img.dim() == 5:
+ if img.size(1) == 1:
+ img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
+ else:
+ img = torch.cat(
+ [
+ torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
+ for x in img
+ ],
+ 0,
+ )
+
+ image_name = os.path.join(result_dir, f"picoclvr_result_{n_epoch:04d}.png")
+ torchvision.utils.save_image(
+ img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
+ )
+ logger(f"wrote {image_name}")
+
+
+######################################################################
+
+
+class MNIST(Task):
+ def __init__(
+ self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
+ ):
+ super().__init__()
+
+ self.nb_train_samples = nb_train_samples
+ self.nb_test_samples = nb_test_samples
+ self.batch_size = batch_size
+ self.device = device
+ data_set = torchvision.datasets.MNIST(root="./data", train=True, download=True)
+ self.train_input = data_set.data[:nb_train_samples].view(-1, 28 * 28).long()
+ data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
+ self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()
+
+ def batches(self, split="train", nb_to_use=-1, desc=None):
+ assert split in {"train", "test"}
+ input = self.train_input if split == "train" else self.test_input
+ if nb_to_use > 0:
+ input = input[:nb_to_use]
+ if desc is None:
+ desc = f"epoch-{split}"
+ for batch in tqdm.tqdm(
+ input.split(self.batch_size), dynamic_ncols=True, desc=desc
+ ):
+ yield batch
+
+ def vocabulary_size(self):
+ return 256
+
+ def produce_results(
+ self, n_epoch, model, result_dir, logger, deterministic_synthesis
+ ):
+ results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
+ ar_mask = torch.full_like(results, 1)
+ masked_inplace_autoregression(
+ model,
+ self.batch_size,
+ results,
+ ar_mask,
+ deterministic_synthesis,
+ device=self.device,
+ )
+ image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
+ torchvision.utils.save_image(
+ 1 - results.reshape(-1, 1, 28, 28) / 255.0,
+ image_name,
+ nrow=16,
+ pad_value=0.8,
+ )
+ logger(f"wrote {image_name}")
+
+
+######################################################################
+
+import maze
+
+
+class Maze(Task):
+ def map2seq(self, *m):
+ return torch.cat([x.flatten(1) for x in m], 1)
+
+ def seq2map(self, s):
+ s = s.reshape(s.size(0), -1, self.height, self.width)
+ return (s[:, k] for k in range(s.size(1)))
+
+ def __init__(
+ self,
+ nb_train_samples,
+ nb_test_samples,
+ batch_size,
+ height,
+ width,
+ nb_walls,
+ device=torch.device("cpu"),
+ ):
+ super().__init__()
+
+ self.batch_size = batch_size
+ self.height = height
+ self.width = width
+ self.device = device
+
+ train_mazes, train_paths, _ = maze.create_maze_data(
+ nb_train_samples,
+ height=height,
+ width=width,
+ nb_walls=nb_walls,
+ progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
+ )
+ self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
+
+ test_mazes, test_paths, _ = maze.create_maze_data(
+ nb_test_samples,
+ height=height,
+ width=width,
+ nb_walls=nb_walls,
+ progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
+ )
+ self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
+
+ self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+
+ def batches(self, split="train", nb_to_use=-1, desc=None):
+ assert split in {"train", "test"}
+ input = self.train_input if split == "train" else self.test_input
+ if nb_to_use > 0:
+ input = input[:nb_to_use]
+ if desc is None:
+ desc = f"epoch-{split}"
+ for batch in tqdm.tqdm(
+ input.split(self.batch_size), dynamic_ncols=True, desc=desc
+ ):
+ yield batch
+
+ def vocabulary_size(self):
+ return self.nb_codes
+
+ def compute_error(
+ self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
+ ):
+ nb_total, nb_correct = 0, 0
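+ # count[i, j] counts the correct predicted paths of length j whose
+ # optimal path length is i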
+ count = torch.zeros(
+ self.width * self.height,
+ self.width * self.height,
+ device=self.device,
+ dtype=torch.int64,
+ )
+
+ for input in self.batches(split, nb_to_use):
+ result = input.clone()
+ ar_mask = result.new_zeros(result.size())
+ ar_mask[:, self.height * self.width :] = 1
+ result *= 1 - ar_mask
+ masked_inplace_autoregression(
+ model,
+ self.batch_size,
+ result,
+ ar_mask,
+ deterministic_synthesis,
+ progress_bar_desc=None,
+ device=self.device,
+ )
+ mazes, paths = self.seq2map(result)
+ path_correctness = maze.path_correctness(mazes, paths)
+ nb_correct += path_correctness.long().sum()
+ nb_total += mazes.size(0)
+
+ optimal_path_lengths = (
+ (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
+ )
+ predicted_path_lengths = (
+ (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
+ )
+ optimal_path_lengths = optimal_path_lengths[path_correctness]
+ predicted_path_lengths = predicted_path_lengths[path_correctness]
+ count[optimal_path_lengths, predicted_path_lengths] += 1
+
+ if count.max() == 0:
+ count = None
+ else:
+ count = count[
+ : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
+ ]
+
+ return nb_total, nb_correct, count
+
+ def produce_results(
+ self, n_epoch, model, result_dir, logger, deterministic_synthesis
+ ):
+ train_nb_total, train_nb_correct, count = self.compute_error(
+ model,
+ "train",
+ nb_to_use=1000,
+ deterministic_synthesis=deterministic_synthesis,
+ )
+ logger(
+ f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
+ )
+
+ test_nb_total, test_nb_correct, count = self.compute_error(
+ model,
+ "test",
+ nb_to_use=1000,
+ deterministic_synthesis=deterministic_synthesis,
+ )
+ logger(
+ f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+ )
+
+ logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
+
+ if count is not None:
+ proportion_optimal = count.diagonal().sum().float() / count.sum()
+ logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
+ with open(
+ os.path.join(result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
+ ) as f:
+ for i in range(count.size(0)):
+ for j in range(count.size(1)):
+ eol = " " if j < count.size(1) - 1 else "\n"
+ f.write(f"{count[i,j]}{eol}")
+
+ input = self.test_input[:48]
+ result = input.clone()
+ ar_mask = result.new_zeros(result.size())
+ ar_mask[:, self.height * self.width :] = 1
+ result *= 1 - ar_mask
+ masked_inplace_autoregression(
+ model,
+ self.batch_size,
+ result,
+ ar_mask,
+ deterministic_synthesis,
+ device=self.device,
+ )
+
+ mazes, paths = self.seq2map(input)
+ _, predicted_paths = self.seq2map(result)
+
+ filename = os.path.join(result_dir, f"maze_result_{n_epoch:04d}.png")
+ maze.save_image(
+ filename,
+ mazes=mazes,
+ target_paths=paths,
+ predicted_paths=predicted_paths,
+ path_correct=maze.path_correctness(mazes, predicted_paths),
+ path_optimal=maze.path_optimality(paths, predicted_paths),
+ )
+ logger(f"wrote {filename}")
+
+
+######################################################################
+
+
+import snake
+
+
+class Snake(Task):
+ def __init__(
+ self,
+ nb_train_samples,
+ nb_test_samples,
+ batch_size,
+ height,
+ width,
+ nb_colors,
+ length,
+ prompt_length,
+ device=torch.device("cpu"),
+ ):
+ super().__init__()
+
+ self.batch_size = batch_size
+ self.height = height
+ self.width = width
+ self.device = device
+ self.prompt_length = prompt_length
+
+ self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
+ nb_train_samples,
+ height,
+ width,
+ nb_colors,
+ length,
+ prompt_length,
+ self.device,
+ )
+ self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
+ nb_test_samples,
+ height,
+ width,
+ nb_colors,
+ length,
+ prompt_length,
+ self.device,
+ )
+
+ self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+
+ def batches(self, split="train", nb_to_use=-1, desc=None):
+ assert split in {"train", "test"}
+ input = self.train_input if split == "train" else self.test_input
+ if nb_to_use > 0:
+ input = input[:nb_to_use]
+ if desc is None:
+ desc = f"epoch-{split}"
+ for batch in tqdm.tqdm(
+ input.split(self.batch_size), dynamic_ncols=True, desc=desc
+ ):
+ yield batch
+
+ def vocabulary_size(self):
+ return self.nb_codes
+
+ def produce_results(
+ self, n_epoch, model, result_dir, logger, deterministic_synthesis
+ ):
+ def compute_nb_correct(input, prior_visits):
+ result = input.clone()
+ i = torch.arange(result.size(1), device=result.device)[None, :]
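+ # The color tokens sit at the even positions; those after the
+ # prompt have to be predicted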
+ ar_mask = (
+ torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
+ .long()
+ .expand_as(result)
+ )
+ result *= 1 - ar_mask
+
+ masked_inplace_autoregression(
+ model,
+ self.batch_size,
+ result,
+ ar_mask,
+ deterministic_synthesis,
+ device=self.device,
+ )
+
+ nb_total = ((prior_visits > 0) * ar_mask).sum()
+
+ nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum()
+
+ return nb_total, nb_correct
+
+ test_nb_total, test_nb_correct = compute_nb_correct(
+ self.test_input[:1000], self.test_prior_visits[:1000]
+ )
+
+ logger(
+ f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+ )
+
+ logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
+
+
+######################################################################
+
+
+import stack
+
+
+class Stack(Task):
+ def __init__(
+ self,
+ nb_train_samples,
+ nb_test_samples,
+ batch_size,
+ logger,
+ nb_steps,
+ nb_stacks,
+ nb_digits,
+ fraction_values_for_train=None,
+ device=torch.device("cpu"),
+ ):
+ super().__init__()
+
+ self.batch_size = batch_size
+ self.nb_steps = nb_steps
+ self.nb_stacks = nb_stacks
+ self.nb_digits = nb_digits
+ self.device = device
+
+ if fraction_values_for_train is None:
+ values_for_train = None
+ values_for_test = None
+ else:
+ values = torch.randperm(10**nb_digits)
+ nb_for_train = int(values.size(0) * fraction_values_for_train)
+ values_for_train = values[:nb_for_train]
+ values_for_test = values[nb_for_train:]
+
+ self.train_input, self.train_stack_counts = stack.generate_sequences(
+ nb_train_samples,
+ nb_steps,
+ nb_stacks,
+ nb_digits,
+ values_for_train,
+ self.device,
+ )
+
+ self.test_input, self.test_stack_counts = stack.generate_sequences(
+ nb_test_samples,
+ nb_steps,
+ nb_stacks,
+ nb_digits,
+ values_for_test,
+ self.device,
+ )
+
+ i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks)
+ counts = self.test_stack_counts.flatten()[i.flatten()]
+ counts = F.one_hot(counts).sum(0)
+ logger(f"test_pop_stack_counts {counts}")
+
+ self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+
+ def batches(self, split="train", nb_to_use=-1, desc=None):
+ assert split in {"train", "test"}
+ input = self.train_input if split == "train" else self.test_input
+ if nb_to_use > 0:
+ input = input[:nb_to_use]
+ if desc is None:
+ desc = f"epoch-{split}"
+ for batch in tqdm.tqdm(
+ input.split(self.batch_size), dynamic_ncols=True, desc=desc
+ ):
+ yield batch
+
+ def vocabulary_size(self):
+ return self.nb_codes
+
+ def produce_results(
+ self, n_epoch, model, result_dir, logger, deterministic_synthesis
+ ):
+ def compute_nb_correct(input):
+ result = input.clone()
+ stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
+ ar_mask = (result != input).long()
+ masked_inplace_autoregression(
+ model,
+ self.batch_size,
+ result,
+ ar_mask,
+ deterministic_synthesis,
+ device=self.device,
+ )
+
+ errors = ((result != input).long() * ar_mask).reshape(
+ -1, 1 + self.nb_digits
+ )
+ ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)
+
+ nb_total = ar_mask.max(1).values.sum()
+ nb_correct = nb_total - errors.max(1).values.sum()
+
+ return nb_total, nb_correct
+
+ test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])
+
+ logger(
+ f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+ )
+
+ logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
+
+ ##############################################################
+ # Log a few generated sequences
+ input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
+ result = input.clone()
+ stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
+ ar_mask = (result != input).long()
+
+ # for n in range(result.size(0)):
+ # logger(
+ # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
+ # )
+
+ masked_inplace_autoregression(
+ model,
+ self.batch_size,
+ result,
+ ar_mask,
+ deterministic_synthesis,
+ device=self.device,
+ )
+
+ for n in range(result.size(0)):
+ logger(
+ f"test_after {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
+ )
+ ##############################################################
+
+
+######################################################################
+
+import rpl
+
+
+class RPL(Task):
+ def tensorize(self, sequences):
+ len_max = max([len(x) for x in sequences])
+ return torch.cat(
+ [
+ torch.tensor(
+ [
+ [
+ self.token2id[str(c)]
+ for c in s + ["<nul>"] * (len_max - len(s))
+ ]
+ for s in sequences
+ ]
+ )
+ ],
+ 0,
+ )
+
+ def seq2str(self, seq):
+ return " ".join([self.id2token[i] for i in seq])
+
+ def __init__(
+ self,
+ nb_train_samples,
+ nb_test_samples,
+ batch_size,
+ nb_starting_values=3,
+ max_input=9,
+ prog_len=6,
+ nb_runs=5,
+ no_prog=False,
+ logger=None,
+ device=torch.device("cpu"),
+ ):
+ super().__init__()
+
+ self.batch_size = batch_size
+ self.device = device
+ self.no_prog = no_prog
+
+ train_sequences = [
+ rpl.generate(
+ nb_starting_values=nb_starting_values,
+ nb_result_values_max=4 * nb_starting_values,
+ max_input=max_input,
+ prog_len=prog_len,
+ nb_runs=nb_runs,
+ )
+ for _ in tqdm.tqdm(range(nb_train_samples), desc="train-data")
+ ]
+
+ test_sequences = [
+ rpl.generate(
+ nb_starting_values=nb_starting_values,
+ nb_result_values_max=4 * nb_starting_values,
+ max_input=max_input,
+ prog_len=prog_len,
+ nb_runs=nb_runs,
+ )
+ for _ in tqdm.tqdm(range(nb_test_samples), desc="test-data")
+ ]
+
+ symbols = list(
+ set(["<nul>"] + [x for l in train_sequences + test_sequences for x in l])
+ )
+ val_max = max([x if type(x) is int else 0 for x in symbols])
+ symbols = list(filter(lambda x: type(x) is str, symbols))
+ symbols.sort()
+ symbols += [str(n) for n in range(val_max + 1)]
+ self.token2id = dict([(c, n) for n, c in enumerate(symbols)])
+ self.id2token = dict([(n, c) for c, n in self.token2id.items()])
+
+ self.t_nul = self.token2id["<nul>"]
+ self.t_input = self.token2id["<in>"]
+ self.t_output = self.token2id["<out>"]
+ self.t_prog = self.token2id["<prg>"]
+ self.t_end = self.token2id["<end>"]
+
+ self.train_input = self.tensorize(train_sequences)
+ self.test_input = self.tensorize(test_sequences)
+
+ if no_prog:
+ # Excise the program from every train and test example
+ k = torch.arange(self.train_input.size(1), device=self.train_input.device)[
+ None, :
+ ]
+ p = (
+ ((self.train_input == self.t_prog).long() * k)
+ .max(1, keepdim=True)
+ .values
+ )
+ self.train_input = (
+ self.train_input * (k <= p).long()
+ + self.t_end * (k == p + 1).long()
+ + self.t_nul * (k > p + 1).long()
+ )
+ k = torch.arange(self.test_input.size(1), device=self.test_input.device)[
+ None, :
+ ]
+ p = (
+ ((self.test_input == self.t_prog).long() * k)
+ .max(1, keepdim=True)
+ .values
+ )
+ self.test_input = (
+ self.test_input * (k <= p).long()
+ + self.t_end * (k == p + 1).long()
+ + self.t_nul * (k > p + 1).long()
+ )
+
+ if logger is not None:
+ logger(f"value_max {val_max}")
+ for x in self.train_input[:25]:
+ end = (x != self.t_nul).nonzero().max().item() + 1
+ seq = [self.id2token[i.item()] for i in x[:end]]
+ s = " ".join(seq)
+ logger(f"example_seq {s}")
+
+ self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+
+ def batches(self, split="train", nb_to_use=-1, desc=None):
+ assert split in {"train", "test"}
+ input = self.train_input if split == "train" else self.test_input
+ if nb_to_use > 0:
+ input = input[:nb_to_use]
+ if desc is None:
+ desc = f"epoch-{split}"
+ for batch in tqdm.tqdm(
+ input.split(self.batch_size), dynamic_ncols=True, desc=desc
+ ):
+ last = (batch != self.t_nul).max(0).values.nonzero().max() + 3
+ batch = batch[:, :last].to(self.device)
+ yield batch
+
+ def vocabulary_size(self):
+ return self.nb_codes
+
+ def produce_results(
+ self, n_epoch, model, result_dir, logger, deterministic_synthesis
+ ):
+ # --------------------------------------------------------------------
+ def compute_nb_errors_prog(input, nb_to_log=0):
+ result = input.clone()
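+ # Mask everything strictly after the first <prg> token: the model
+ # has to generate the program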
+ s = (result == self.t_prog).long()
+ ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
+ result = (1 - ar_mask) * result + ar_mask * self.t_nul
+
+ masked_inplace_autoregression(
+ model,
+ self.batch_size,
+ result,
+ ar_mask,
+ deterministic_synthesis,
+ device=self.device,
+ )
+
+ sum_nb_total, sum_nb_errors = 0, 0
+ for one_input, one_result in zip(input, result):
+ seq = [self.id2token[i.item()] for i in one_result]
+ nb_total, nb_errors, prog, stacks = rpl.compute_nb_errors(seq)
+ sum_nb_total += 1
+ sum_nb_errors += 0 if nb_errors == 0 else 1
+ if nb_to_log > 0:
+ gt_seq = [self.id2token[i.item()] for i in one_input]
+ _, _, gt_prog, _ = rpl.compute_nb_errors(gt_seq)
+ gt_prog = " ".join([str(x) for x in gt_prog])
+ prog = " ".join([str(x) for x in prog])
+ comment = "*" if nb_errors == 0 else "-"
+ logger(f"{comment} PROG [{gt_prog}] PREDICTED [{prog}]")
+ for start_stack, target_stack, result_stack, correct in stacks:
+ comment = "*" if correct else "-"
+ start_stack = " ".join([str(x) for x in start_stack])
+ target_stack = " ".join([str(x) for x in target_stack])
+ result_stack = " ".join([str(x) for x in result_stack])
+ logger(
+ f" {comment} [{start_stack}] -> [{target_stack}] PREDICTED [{result_stack}]"
+ )
+ nb_to_log -= 1
+
+ return sum_nb_total, sum_nb_errors
+
+ # --------------------------------------------------------------------
+ def compute_nb_errors_output(input, nb_to_log=0):
+ result = input.clone()
+ k = torch.arange(result.size(1), device=result.device)[None, :]
+ last_output_idx = (
+ ((result == self.t_output) * k).max(dim=1, keepdim=True).values
+ )
+ first_prog_idx = (
+ ((result == self.t_prog) * k).max(dim=1, keepdim=True).values
+ )
+ ar_mask = (k > last_output_idx).long() * (k < first_prog_idx).long()
+ result = (1 - ar_mask) * result + ar_mask * self.t_nul
+
+ masked_inplace_autoregression(
+ model,
+ self.batch_size,
+ result,
+ ar_mask,
+ deterministic_synthesis,
+ device=self.device,
+ )
+
+ sum_nb_total, sum_nb_errors = 0, 0
+ for one_input, one_result, i, j in zip(
+ input, result, last_output_idx, first_prog_idx
+ ):
+ sum_nb_total += 1
+ correct = (one_input - one_result).abs().max() == 0
+ sum_nb_errors += 0 if correct else 1
+ if nb_to_log > 0:
+ result_stack = [
+ self.id2token[i.item()] for i in one_result[i : j + 1]
+ ]
+ target_stack = [
+ self.id2token[i.item()] for i in one_input[i : j + 1]
+ ]
+ comment = "*" if correct else "-"
+ result_stack = " ".join([str(x) for x in result_stack])
+ target_stack = " ".join([str(x) for x in target_stack])
+ logger(
+ f"output_test {comment} [{target_stack}] PREDICTED [{result_stack}]"
+ )
+ nb_to_log -= 1
+
+ return sum_nb_total, sum_nb_errors
+
+ # --------------------------------------------------------------------
+
+ if not self.no_prog:
+ test_nb_total, test_nb_errors = compute_nb_errors_prog(
+ self.test_input[:1000].to(self.device), nb_to_log=10
+ )
+
+ logger(
+ f"accuracy_prog_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
+ )
+
+ logger(f"main_test_accuracy {n_epoch} {1-test_nb_errors/test_nb_total}")
+
+ test_nb_total, test_nb_errors = compute_nb_errors_output(
+ self.test_input[:1000].to(self.device), nb_to_log=10
+ )
+
+ logger(
+ f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
+ )
+
+ if save_attention_image is not None:
+ ns = torch.randint(self.test_input.size(0), (1,)).item()
+ input = self.test_input[ns : ns + 1].clone()
+ last = (input != self.t_nul).max(0).values.nonzero().max() + 3
+ input = input[:, :last].to(self.device)
+
+ with torch.autograd.no_grad():
+ t = model.training
+ model.eval()
+ model.record_attention(True)
+ model(BracketedSequence(input))
+ model.train(t)
+ ram = model.retrieve_attention()
+ model.record_attention(False)
+
+ tokens_output = [self.id2token[i.item()] for i in input[0]]
+ tokens_input = ["n/a"] + tokens_output[:-1]
+ for n_head in range(ram[0].size(1)):
+ filename = os.path.join(
+ result_dir, f"rpl_attention_{n_epoch}_h{n_head}.pdf"
+ )
+ attention_matrices = [m[0, n_head] for m in ram]
+ save_attention_image(
+ filename,
+ tokens_input,
+ tokens_output,
+ attention_matrices,
+ k_top=10,
+ # min_total_attention=0.9,
+ token_gap=12,
+ layer_gap=50,
+ )
+ logger(f"wrote {filename}")
+
+
+######################################################################
+
+
+import expr
+
+
+class Expr(Task):
+ def tensorize(self, sequences):
+ len_max = max([len(x) for x in sequences])
+ return torch.cat(
+ [
+ torch.tensor(
+ [
+ [self.char2id[c] for c in s + "#" * (len_max - len(s))]
+ for s in sequences
+ ]
+ )
+ ],
+ 0,
+ ).to(self.device)
+
+ def __init__(
+ self,
+ nb_train_samples,
+ nb_test_samples,
+ nb_variables,
+ sequence_length,
+ operand_max,
+ result_max,
+ batch_size,
+ device=torch.device("cpu"),
+ ):
+ super().__init__()
+
+ self.batch_size = batch_size
+ self.device = device
+
+ train_sequences = expr.generate_sequences(
+ nb_train_samples,
+ nb_variables=nb_variables,
+ length=sequence_length,
+ operand_max=operand_max,
+ result_max=result_max,
+ )
+
+ test_sequences = expr.generate_sequences(
+ nb_test_samples,
+ nb_variables=nb_variables,
+ length=sequence_length,
+ operand_max=operand_max,
+ result_max=result_max,
+ )
+
+ symbols = list(set("#" + "".join(train_sequences + test_sequences)))
+ symbols.sort()
+
+ self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
+ self.id2char = dict([(n, c) for c, n in self.char2id.items()])
+
+ self.filler, self.space = self.char2id["#"], self.char2id[" "]
+
+ self.train_input = self.tensorize(train_sequences)
+ self.test_input = self.tensorize(test_sequences)
+
+ self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+
+ def batches(self, split="train", nb_to_use=-1, desc=None):
+ assert split in {"train", "test"}
+ input = self.train_input if split == "train" else self.test_input
+ if nb_to_use > 0:
+ input = input[:nb_to_use]
+ if desc is None:
+ desc = f"epoch-{split}"
+ for batch in tqdm.tqdm(
+ input.split(self.batch_size), dynamic_ncols=True, desc=desc
+ ):
+ last = (batch != self.filler).max(0).values.nonzero().max() + 3
+ batch = batch[:, :last]
+ yield batch
+
+ def vocabulary_size(self):
+ return self.nb_codes
+
+ def seq2str(self, s):
+ return "".join([self.id2char[k.item()] for k in s])
+
+ def produce_results(
+ self,
+ n_epoch,
+ model,
+ result_dir,
+ logger,
+ deterministic_synthesis,
+ input_file=None,
+ ):
+ def compute_nb_correct(input):
+ result = input.clone()
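+ # Mask everything strictly after the first space: the model has
+ # to generate the variable values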
+ s = (result == self.space).long()
+ ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
+ result = (1 - ar_mask) * result + ar_mask * self.filler
+ masked_inplace_autoregression(
+ model,
+ self.batch_size,
+ result,
+ ar_mask,
+ deterministic_synthesis,
+ device=self.device,
+ )
+
+ nb_total = input.size(0)
+ nb_correct = (input == result).long().min(1).values.sum()
+
+ #######################################################################
+ # Compute predicted vs. true variable values
+
+ nb_delta = torch.zeros(5, dtype=torch.int64)
+ nb_missed = 0
+
+ values_input = expr.extract_results([self.seq2str(s) for s in input])
+ values_result = expr.extract_results([self.seq2str(s) for s in result])
+
+ filename = os.path.join(result_dir, f"expr_result_{n_epoch:04d}.txt")
+
+ with open(filename, "w") as f:
+ for i, r in zip(values_input, values_result):
+ for n, vi in i.items():
+ vr = r.get(n)
+ f.write(f"{vi} {-1 if vr is None else vr}\n")
+
+ if vr is None or vr < 0:
+ nb_missed += 1
+ else:
+ d = abs(vr - vi)
+ if d >= nb_delta.size(0):
+ nb_missed += 1
+ else:
+ nb_delta[d] += 1
+
+ ######################################################################
+
+ return nb_total, nb_correct, nb_delta, nb_missed
+
+ (
+ test_nb_total,
+ test_nb_correct,
+ test_nb_delta,
+ test_nb_missed,
+ ) = compute_nb_correct(self.test_input[:10000])
+
+ logger(
+ f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+ )
+
+ logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
+
+ nb_total = test_nb_delta.sum() + test_nb_missed
+ for d in range(test_nb_delta.size(0)):
+ logger(
+ f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%"
+ )
+ logger(
+ f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%"
+ )
+
+ ##############################################################
+ # Log a few generated sequences
+ if input_file is None:
+ input = self.test_input[:10]
+ else:
+ with open(input_file, "r") as f:
+ sequences = [e.strip() for e in f.readlines()]
+ sequences = [s + " " + "#" * 50 for s in sequences]
+ input = self.tensorize(sequences)
+
+ result = input.clone()
+ s = (result == self.space).long()
+ ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
+ result = (1 - ar_mask) * result + ar_mask * self.filler
+
+ for n in range(result.size(0)):
+ logger(f"test_before {self.seq2str(result[n])}")
+
+ masked_inplace_autoregression(
+ model,
+ self.batch_size,
+ result,
+ ar_mask,
+ deterministic_synthesis,
+ device=self.device,
+ )
+
+ correct = (1 - ar_mask) * self.space + ar_mask * input
+ for n in range(result.size(0)):
+ comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""
+ logger(f"test_after {self.seq2str(result[n])} {comment}")
+ logger(f"truth {self.seq2str(correct[n])}")
+ ##############################################################
+
+
+######################################################################
+
+import grid
+
+
+class Grid(Task):
+ # Make a tensor from a list of strings
+ def str2tensor(self, descr):
+ token_descr = [s.strip().split(" ") for s in descr]
+ l = max([len(s) for s in token_descr])
+ token_descr = [s + ["#"] * (l - len(s)) for s in token_descr]
+ id_descr = [[self.token2id[u] for u in s] for s in token_descr]
+ return torch.tensor(id_descr, device=self.device)
+
+ # Make a list of strings from a tensor
+ def tensor2str(self, x):
+ return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
+
+ # Trim as many columns of `token` as possible on the left and on
+ # the right of the first tensor. If z is a tuple, all its elements
+ # are trimmed according to the trimming computed for the first one
+ def trim(self, z, token="#"):
+ n = self.token2id[token]
+ if type(z) == tuple:
+ x = z[0]
+ i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
+ a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
+ return tuple([t[:, a:b] for t in z])
+ else:
+ i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
+ a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
+ return z[:, a:b]
+
+ ######################
+
+ def __init__(
+ self,
+ nb_train_samples,
+ nb_test_samples,
+ batch_size,
+ size,
+ logger=None,
+ device=torch.device("cpu"),
+ ):
+ super().__init__()
+
+ self.device = device
+ self.batch_size = batch_size
+ self.grid_factory = grid.GridFactory(size=size)
+
+ if logger is not None:
+ logger(
+ f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
+ )
+
+ self.train_descr = self.grid_factory.generate_samples(
+ nb_train_samples, lambda r: tqdm.tqdm(r)
+ )
+ self.test_descr = self.grid_factory.generate_samples(
+ nb_test_samples, lambda r: tqdm.tqdm(r)
+ )
+
+ # Build the tokenizer
+ tokens = set()
+ for d in [self.train_descr, self.test_descr]:
+ for s in d:
+ for t in s.strip().split(" "):
+ tokens.add(t)
+ # make this set a sorted list to get the same tensors given
+ # the same descr
+ tokens = list(tokens)
+ tokens.sort()
+ tokens = ["#"] + tokens
+ self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
+ self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
+ self.t_nul = self.token2id["#"]
+ self.t_true = self.token2id["true"]
+ self.t_false = self.token2id["false"]
+
+ # Tokenize the train and test sets
+ self.train_input = self.str2tensor(self.train_descr)
+ self.test_input = self.str2tensor(self.test_descr)
+
+ def batches(self, split="train"):
+ assert split in {"train", "test"}
+ input = self.train_input if split == "train" else self.test_input
+ for batch in tqdm.tqdm(
+ input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
+ ):
+ yield self.trim(batch)
+
+ def vocabulary_size(self):
+ return len(self.token2id)
+
+ def produce_results(
+ self, n_epoch, model, result_dir, logger, deterministic_synthesis
+ ):
+ correct = self.test_input[:1000]
+ result = correct.clone()
+ ar_mask = torch.logical_or(result == self.t_true, result == self.t_false).long()
+ result *= 1 - ar_mask # paraaaaanoiaaaaaaa
+
+ logger(f"----------------------------------------------------------")
+
+ for e in self.tensor2str(result[:10]):
+ logger(f"test_before {e}")
+
+ masked_inplace_autoregression(
+ model,
+ self.batch_size,
+ result,
+ ar_mask,
+ deterministic_synthesis,
+ device=self.device,
+ )
+
+ logger(f"----------------------------------------------------------")
+
+ for e in self.tensor2str(result[:10]):
+ logger(f"test_after {e}")
+
+ logger(f"----------------------------------------------------------")
+
+ nb_total = ar_mask.sum().item()
+ nb_correct = ((correct == result).long() * ar_mask).sum().item()
+
+ logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
+ logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
+
+
+######################################################################
+
+import qmlp
+
+
+class QMLP(Task):
+ ######################
+
+ def __init__(
+ self,
+ nb_train_samples,
+ nb_test_samples,
+ batch_size,
+ result_dir,
+ logger=None,
+ device=torch.device("cpu"),
+ ):
+ super().__init__()
+
+ self.device = device
+ self.batch_size = batch_size
+ self.nb_samples_per_mlp = 256
+
+ if logger is not None:
+ logger(
+ f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
+ )
+
+ seq, q_test_set, test_error = qmlp.generate_sequence_and_test_set(
+ nb_mlps=nb_train_samples + nb_test_samples,
+ nb_samples=self.nb_samples_per_mlp,
+ device=self.device,
+ batch_size=64,
+ nb_epochs=250,
+ nb_mlps_per_batch=1024,
+ )
+
+ self.train_input = seq[:nb_train_samples]
+ self.train_q_test_set = q_test_set[:nb_train_samples]
+ self.train_ref_test_errors = test_error[:nb_train_samples]
+ self.test_input = seq[nb_train_samples:]
+ self.test_q_test_set = q_test_set[nb_train_samples:]
+ self.test_ref_test_errors = test_error[nb_train_samples:]
+
+ filename = os.path.join(result_dir, f"train_errors_ref.dat")
+ with open(filename, "w") as f:
+ for e in self.train_ref_test_errors:
+ f.write(f"{e}\n")
+
+ filename = os.path.join(result_dir, f"test_errors_ref.dat")
+ with open(filename, "w") as f:
+ for e in self.test_ref_test_errors:
+ f.write(f"{e}\n")
+
+ self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+
+ def batches(self, split="train"):
+ assert split in {"train", "test"}
+ input = self.train_input if split == "train" else self.test_input
+ for batch in tqdm.tqdm(
+ input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
+ ):
+ yield batch
+
+ def vocabulary_size(self):
+ return self.nb_codes
+
+ def produce_results(
+ self, n_epoch, model, result_dir, logger, deterministic_synthesis
+ ):
+ correct = self.test_input[:1000]
+ result = correct.clone()
+ ar_mask = (
+ torch.arange(result.size(1), device=result.device)
+ # the MLP parameters start right after the separator token
+ > self.nb_samples_per_mlp * 3
+ ).long()[None, :]
+ ar_mask = ar_mask.expand_as(result)
+ result *= 1 - ar_mask # paraaaaanoiaaaaaaa
+
+ masked_inplace_autoregression(
+ model,
+ self.batch_size,
+ result,
+ ar_mask,
+ deterministic_synthesis,
+ device=self.device,
+ )
+
+ q_train_set = result[:, : self.nb_samples_per_mlp * 3]
+ q_params = result[:, self.nb_samples_per_mlp * 3 + 1 :]
+ error_test = qmlp.evaluate_q_params(q_params, self.test_q_test_set)
+
+ filename = os.path.join(result_dir, f"test_errors_{n_epoch:04d}.dat")
+ with open(filename, "w") as f:
+ for e in error_test:
+ f.write(f"{e}\n")
+
+
+######################################################################
--- /dev/null
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import math, sys, tqdm
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+import cairo
+
+######################################################################
+
+
+class Box:
+ nb_rgb_levels = 10
+
+ def __init__(self, x, y, w, h, r, g, b):
+ self.x = x
+ self.y = y
+ self.w = w
+ self.h = h
+ self.r = r
+ self.g = g
+ self.b = b
+
+ def collision(self, scene):
+ for c in scene:
+ if (
+ self is not c
+ and max(self.x, c.x) <= min(self.x + self.w, c.x + c.w)
+ and max(self.y, c.y) <= min(self.y + self.h, c.y + c.h)
+ ):
+ return True
+ return False
+
+
+######################################################################
+
+
+class Normalizer(nn.Module):
+ def __init__(self, mu, std):
+ super().__init__()
+ self.register_buffer("mu", mu)
+ self.register_buffer("log_var", 2 * torch.log(std))
+
+ def forward(self, x):
+ return (x - self.mu) / torch.exp(self.log_var / 2.0)
+
+
+class SignSTE(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ # torch.sign() would take three values (-1, 0, +1); we want a
+ # strict binary +/-1
+ s = (x >= 0).float() * 2 - 1
+
+ if self.training:
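+ # Straight-through estimator: the value is that of s, the
+ # gradient that of tanh(x)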
+ u = torch.tanh(x)
+ return s + u - u.detach()
+ else:
+ return s
+
+
+class DiscreteSampler2d(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ s = (x >= x.max(-3, keepdim=True).values).float()
+
+ if self.training:
+ u = x.softmax(dim=-3)
+ return s + u - u.detach()
+ else:
+ return s
+
+
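+# Entropy regularizer: p is the per-bit activation probability averaged over
+# the batch and h its binary entropy in bits; entropy above h_threshold is
+# clamped, so the loss only pushes under-used bits toward balanced usage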
+def loss_H(binary_logits, h_threshold=1):
+ p = binary_logits.sigmoid().mean(0)
+ h = (-p.xlogy(p) - (1 - p).xlogy(1 - p)) / math.log(2)
+ h.clamp_(max=h_threshold)
+ return h_threshold - h.mean()
+
+
+def train_encoder(
+ train_input,
+ test_input,
+ depth,
+ nb_bits_per_token,
+ dim_hidden=48,
+ lambda_entropy=0.0,
+ lr_start=1e-3,
+ lr_end=1e-4,
+ nb_epochs=10,
+ batch_size=25,
+ logger=None,
+ device=torch.device("cpu"),
+):
+ mu, std = train_input.float().mean(), train_input.float().std()
+
+ def encoder_core(depth, dim):
+ l = [
+ [
+ nn.Conv2d(
+ dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2
+ ),
+ nn.ReLU(),
+ nn.Conv2d(dim * 2**k, dim * 2 ** (k + 1), kernel_size=2, stride=2),
+ nn.ReLU(),
+ ]
+ for k in range(depth)
+ ]
+
+ return nn.Sequential(*[x for m in l for x in m])
+
+ def decoder_core(depth, dim):
+ l = [
+ [
+ nn.ConvTranspose2d(
+ dim * 2 ** (k + 1), dim * 2**k, kernel_size=2, stride=2
+ ),
+ nn.ReLU(),
+ nn.ConvTranspose2d(
+ dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2
+ ),
+ nn.ReLU(),
+ ]
+ for k in range(depth - 1, -1, -1)
+ ]
+
+ return nn.Sequential(*[x for m in l for x in m])
+
+ encoder = nn.Sequential(
+ Normalizer(mu, std),
+ nn.Conv2d(3, dim_hidden, kernel_size=1, stride=1),
+ nn.ReLU(),
+ # 64x64
+ encoder_core(depth=depth, dim=dim_hidden),
+ # 8x8
+ nn.Conv2d(dim_hidden * 2**depth, nb_bits_per_token, kernel_size=1, stride=1),
+ )
+
+ quantizer = SignSTE()
+
+ decoder = nn.Sequential(
+ nn.Conv2d(nb_bits_per_token, dim_hidden * 2**depth, kernel_size=1, stride=1),
+ # 8x8
+ decoder_core(depth=depth, dim=dim_hidden),
+ # 64x64
+ nn.ConvTranspose2d(dim_hidden, 3 * Box.nb_rgb_levels, kernel_size=1, stride=1),
+ )
+
+ model = nn.Sequential(encoder, decoder)
+
+ nb_parameters = sum(p.numel() for p in model.parameters())
+
+ logger(f"vqae nb_parameters {nb_parameters}")
+
+ model.to(device)
+
+ for k in range(nb_epochs):
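+ # learning rate interpolated geometrically from lr_start down to lr_end;
+ # note that a fresh Adam instance, with reset moment estimates, is created
+ # at every epoch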
+ lr = math.exp(
+ math.log(lr_start) + math.log(lr_end / lr_start) / max(1, nb_epochs - 1) * k
+ )
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+
+ acc_train_loss = 0.0
+
+ for input in tqdm.tqdm(train_input.split(batch_size), desc="vqae-train"):
+ input = input.to(device)
+ z = encoder(input)
+ zq = quantizer(z)
+ output = decoder(zq)
+
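+ # view the logits as (N, nb_rgb_levels, 3, H, W) so that cross_entropy
+ # classifies the quantized intensity of each RGB channel independently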
+ output = output.reshape(
+ output.size(0), -1, 3, output.size(2), output.size(3)
+ )
+
+ train_loss = F.cross_entropy(output, input)
+
+ if lambda_entropy > 0:
+ train_loss = train_loss + lambda_entropy * loss_H(z, h_threshold=0.5)
+
+ acc_train_loss += train_loss.item() * input.size(0)
+
+ optimizer.zero_grad()
+ train_loss.backward()
+ optimizer.step()
+
+ acc_test_loss = 0.0
+
+ for input in tqdm.tqdm(test_input.split(batch_size), desc="vqae-test"):
+ input = input.to(device)
+ z = encoder(input)
+ zq = quantizer(z)
+ output = decoder(zq)
+
+ output = output.reshape(
+ output.size(0), -1, 3, output.size(2), output.size(3)
+ )
+
+ test_loss = F.cross_entropy(output, input)
+
+ acc_test_loss += test_loss.item() * input.size(0)
+
+ train_loss = acc_train_loss / train_input.size(0)
+ test_loss = acc_test_loss / test_input.size(0)
+
+ logger(f"vqae train {k} lr {lr} train_loss {train_loss} test_loss {test_loss}")
+ sys.stdout.flush()
+
+ return encoder, quantizer, decoder
+
+
+######################################################################
+
+
+def scene2tensor(xh, yh, scene, size):
+ width, height = size, size
+ # cairo expects height rows of width pixels; the surface draws directly
+ # into the tensor's memory through the shared numpy view
+ pixel_map = torch.ByteTensor(height, width, 4).fill_(255)
+ data = pixel_map.numpy()
+ surface = cairo.ImageSurface.create_for_data(
+ data, cairo.FORMAT_ARGB32, width, height
+ )
+
+ ctx = cairo.Context(surface)
+ ctx.set_fill_rule(cairo.FILL_RULE_EVEN_ODD)
+
+ for b in scene:
+ ctx.move_to(b.x * size, b.y * size)
+ ctx.rel_line_to(b.w * size, 0)
+ ctx.rel_line_to(0, b.h * size)
+ ctx.rel_line_to(-b.w * size, 0)
+ ctx.close_path()
+ ctx.set_source_rgba(
+ b.r / (Box.nb_rgb_levels - 1),
+ b.g / (Box.nb_rgb_levels - 1),
+ b.b / (Box.nb_rgb_levels - 1),
+ 1.0,
+ )
+ ctx.fill()
+
+ hs = size * 0.1
+ ctx.set_source_rgba(0.0, 0.0, 0.0, 1.0)
+ ctx.move_to(xh * size - hs / 2, yh * size - hs / 2)
+ ctx.rel_line_to(hs, 0)
+ ctx.rel_line_to(0, hs)
+ ctx.rel_line_to(-hs, 0)
+ ctx.close_path()
+ ctx.fill()
+
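+ # FORMAT_ARGB32 is stored as B, G, R, A bytes on little-endian hosts: keep
+ # the first three bytes, flip them to R, G, B, put channels first, and
+ # quantize 0..255 down to Box.nb_rgb_levels levels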
+ return (
+ pixel_map[None, :, :, :3]
+ .flip(-1)
+ .permute(0, 3, 1, 2)
+ .long()
+ .mul(Box.nb_rgb_levels)
+ .floor_divide(256)
+ )
+
+
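+# Rejection sampling: a candidate box is kept only if it does not overlap any
+# already accepted one, so a scene contains at most nb_insert_attempts boxes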
+def random_scene(nb_insert_attempts=3):
+ scene = []
+ colors = [
+ ((Box.nb_rgb_levels - 1), 0, 0),
+ (0, (Box.nb_rgb_levels - 1), 0),
+ (0, 0, (Box.nb_rgb_levels - 1)),
+ ((Box.nb_rgb_levels - 1), (Box.nb_rgb_levels - 1), 0),
+ (
+ (Box.nb_rgb_levels * 2) // 3,
+ (Box.nb_rgb_levels * 2) // 3,
+ (Box.nb_rgb_levels * 2) // 3,
+ ),
+ ]
+
+ for k in range(nb_insert_attempts):
+ wh = torch.rand(2) * 0.2 + 0.2
+ xy = torch.rand(2) * (1 - wh)
+ c = colors[torch.randint(len(colors), (1,)).item()]
+ b = Box(
+ xy[0].item(), xy[1].item(), wh[0].item(), wh[1].item(), c[0], c[1], c[2]
+ )
+ if not b.collision(scene):
+ scene.append(b)
+
+ return scene
+
+
+def generate_episode(steps, size=64):
+ delta = 0.1
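+ # each effect is (grasp, dx, dy): with grasp the hand drags the box under
+ # it, otherwise only the hand moves; motions that would leave the unit
+ # square or create a collision are cancelled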
+ effects = [
+ (False, 0, 0),
+ (False, delta, 0),
+ (False, 0, delta),
+ (False, -delta, 0),
+ (False, 0, -delta),
+ (True, delta, 0),
+ (True, 0, delta),
+ (True, -delta, 0),
+ (True, 0, -delta),
+ ]
+
+ while True:
+ frames = []
+
+ scene = random_scene()
+ xh, yh = tuple(x.item() for x in torch.rand(2))
+
+ actions = torch.randint(len(effects), (len(steps),))
+ nb_changes = 0
+
+ for s, a in zip(steps, actions):
+ if s:
+ frames.append(scene2tensor(xh, yh, scene, size=size))
+
+ grasp, dx, dy = effects[a]
+
+ if grasp:
+ for b in scene:
+ if b.x <= xh and b.x + b.w >= xh and b.y <= yh and b.y + b.h >= yh:
+ x, y = b.x, b.y
+ b.x += dx
+ b.y += dy
+ if (
+ b.x < 0
+ or b.y < 0
+ or b.x + b.w > 1
+ or b.y + b.h > 1
+ or b.collision(scene)
+ ):
+ b.x, b.y = x, y
+ else:
+ xh += dx
+ yh += dy
+ nb_changes += 1
+ else:
+ x, y = xh, yh
+ xh += dx
+ yh += dy
+ if xh < 0 or xh > 1 or yh < 0 or yh > 1:
+ xh, yh = x, y
+
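+ # keep the episode only if more than a third of the actions actually
+ # displaced a box, otherwise resample from scratch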
+ if nb_changes > len(steps) // 3:
+ break
+
+ return frames, actions
+
+
+######################################################################
+
+
+def generate_episodes(nb, steps):
+ all_frames, all_actions = [], []
+ for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world-data"):
+ frames, actions = generate_episode(steps)
+ all_frames += frames
+ all_actions += [actions[None, :]]
+ return torch.cat(all_frames, 0).contiguous(), torch.cat(all_actions, 0)
+
+
+def create_data_and_processors(
+ nb_train_samples,
+ nb_test_samples,
+ mode,
+ nb_steps,
+ depth=3,
+ nb_bits_per_token=8,
+ nb_epochs=10,
+ device=torch.device("cpu"),
+ device_storage=torch.device("cpu"),
+ logger=None,
+):
+ assert mode in ["first_last"]
+
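+ # with "first_last", only the first and the last frames of each episode
+ # are rendered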
+ if mode == "first_last":
+ steps = [True] + [False] * (nb_steps + 1) + [True]
+
+ if logger is None:
+ logger = lambda s: print(s)
+
+ train_input, train_actions = generate_episodes(nb_train_samples, steps)
+ train_input, train_actions = train_input.to(device_storage), train_actions.to(
+ device_storage
+ )
+ test_input, test_actions = generate_episodes(nb_test_samples, steps)
+ test_input, test_actions = test_input.to(device_storage), test_actions.to(
+ device_storage
+ )
+
+ encoder, quantizer, decoder = train_encoder(
+ train_input,
+ test_input,
+ depth=depth,
+ nb_bits_per_token=nb_bits_per_token,
+ lambda_entropy=1.0,
+ nb_epochs=nb_epochs,
+ logger=logger,
+ device=device,
+ )
+ encoder.train(False)
+ quantizer.train(False)
+ decoder.train(False)
+
+ z = encoder(train_input[:1].to(device))
+ pow2 = (2 ** torch.arange(z.size(1), device=device))[None, None, :]
+ z_h, z_w = z.size(2), z.size(3)
+
+ logger(f"vqae input {train_input[0].size()} output {z[0].size()}")
+
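+ # Packs the nb_bits_per_token binarized channels of every spatial location
+ # of the encoding into a single integer token, flattening the z_h x z_w
+ # grid into a sequence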
+ def frame2seq(input, batch_size=25):
+ seq = []
+ p = pow2.to(device)
+ for x in input.split(batch_size):
+ x = x.to(device)
+ z = encoder(x)
+ ze_bool = (quantizer(z) >= 0).long()
+ output = (
+ ze_bool.permute(0, 2, 3, 1).reshape(
+ ze_bool.size(0), -1, ze_bool.size(1)
+ )
+ * p
+ ).sum(-1)
+
+ seq.append(output)
+
+ return torch.cat(seq, dim=0)
+
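+ # Inverse of frame2seq: unpacks every token into its bit planes, decodes
+ # them into per-channel level logits, and samples pixel values from a
+ # Categorical at low temperature T, i.e. close to an argmax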
+ def seq2frame(input, batch_size=25, T=1e-2):
+ frames = []
+ p = pow2.to(device)
+ for seq in input.split(batch_size):
+ seq = seq.to(device)
+ zd_bool = (seq[:, :, None] // p) % 2
+ zd_bool = zd_bool.reshape(zd_bool.size(0), z_h, z_w, -1).permute(0, 3, 1, 2)
+ logits = decoder(zd_bool * 2.0 - 1.0)
+ logits = logits.reshape(
+ logits.size(0), -1, 3, logits.size(2), logits.size(3)
+ ).permute(0, 2, 3, 4, 1)
+ output = torch.distributions.categorical.Categorical(
+ logits=logits / T
+ ).sample()
+
+ frames.append(output)
+
+ return torch.cat(frames, dim=0)
+
+ return train_input, train_actions, test_input, test_actions, frame2seq, seq2frame
+
+
+######################################################################
+
+if __name__ == "__main__":
+ (
+ train_input,
+ train_actions,
+ test_input,
+ test_actions,
+ frame2seq,
+ seq2frame,
+ ) = create_data_and_processors(
+ 25000,
+ 1000,
+ nb_epochs=5,
+ mode="first_last",
+ nb_steps=20,
+ )
+
+ input = test_input[:256]
+
+ seq = frame2seq(input)
+ output = seq2frame(seq)
+
+ torchvision.utils.save_image(
+ input.float() / (Box.nb_rgb_levels - 1), "orig.png", nrow=16
+ )
+
+ torchvision.utils.save_image(
+ output.float() / (Box.nb_rgb_levels - 1), "qtiz.png", nrow=16
+ )