From 4395f9a90218819997c706de9505cda1c86ad507 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 6 Jan 2024 10:56:07 +0100 Subject: [PATCH 1/1] Initial commit --- expr.py | 105 ++++ ffutils.py | 108 ++++ graph.py | 185 ++++++ grid.py | 236 ++++++++ main.py | 912 ++++++++++++++++++++++++++++ maze.py | 309 ++++++++++ memload.py | 84 +++ mygpt.py | 954 +++++++++++++++++++++++++++++ picoclvr.py | 370 ++++++++++++ problems.py | 490 +++++++++++++++ pscan.py | 139 +++++ qmlp.py | 378 ++++++++++++ rpl.py | 177 ++++++ snake.py | 132 ++++ stack.py | 107 ++++ tasks.py | 1663 +++++++++++++++++++++++++++++++++++++++++++++++++++ world.py | 485 +++++++++++++++ 17 files changed, 6834 insertions(+) create mode 100755 expr.py create mode 100755 ffutils.py create mode 100755 graph.py create mode 100755 grid.py create mode 100755 main.py create mode 100755 maze.py create mode 100755 memload.py create mode 100755 mygpt.py create mode 100755 picoclvr.py create mode 100755 problems.py create mode 100755 pscan.py create mode 100755 qmlp.py create mode 100755 rpl.py create mode 100755 snake.py create mode 100755 stack.py create mode 100755 tasks.py create mode 100755 world.py diff --git a/expr.py b/expr.py new file mode 100755 index 0000000..685efd3 --- /dev/null +++ b/expr.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. 
+# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import math, re + +import torch, torchvision + +from torch import nn +from torch.nn import functional as F + + +def random_var(nb_variables=None, variables=None): + if variables is None: + return chr(ord("A") + torch.randint(nb_variables, (1,)).item()) + else: + l = list(variables) + return l[torch.randint(len(l), (1,)).item()] + + +def random_expr(variables, operand_max, budget): + if budget <= 5: + op = torch.randint(2, (1,)).item() + if op == 0 and len(variables) > 0: + return random_var(variables=variables) + else: + return str(torch.randint(operand_max + 1, (1,)).item()) + else: + op = torch.randint(3, (1,)).item() + if op == 0: + e = random_expr(variables, operand_max, budget - 2) + if ("+" in e or "-" in e or "*" in e) and (e[0] != "(" or e[-1] != ")"): + return "(" + e + ")" + else: + return e + else: + b = 2 + torch.randint(budget - 5, (1,)).item() + e1 = random_expr(variables, operand_max, b) + e2 = random_expr(variables, operand_max, budget - b - 1) + if op == 1: + return e1 + "+" + e2 + elif op == 2: + return e1 + "*" + e2 + + +def generate_program(nb_variables, operand_max, length): + s = "" + variables = set() + + while len(s) < length: + v = random_var(nb_variables=nb_variables) + s += v + "=" + random_expr(variables, operand_max, budget=20) + ";" + variables.add(v) + + return s, variables + + +def generate_sequences(nb, nb_variables=5, length=20, operand_max=9, result_max=99): + assert nb_variables <= 26 + sequences = [] + + for n in range(nb): + # We take length itself half of the time, and uniform between + # 1 and length otherwise. 
The actual length can be slightly + # greater + + l = min(length, 1 + torch.randint(length * 2, (1,)).item()) + result = None + while result == None or max(result.values()) > result_max: + p, v = generate_program(nb_variables, operand_max, l) + v = ", ".join(['"' + v + '": ' + v for v in v]) + ldict = {} + exec(p + "result={" + v + "}", globals(), ldict) + result = ldict["result"] + + k = list(result.keys()) + k.sort() + sequences.append(p + " " + "".join([v + ":" + str(result[v]) + ";" for v in k])) + + return sequences + + +def extract_results(seq): + f = lambda a: (a[0], -1 if a[1] == "" else int(a[1])) + results = [ + dict([f(tuple(x.split(":"))) for x in re.findall("[A-Z]:[0-9]*", s)]) + for s in seq + ] + return results + + +if __name__ == "__main__": + import time + + start_time = time.perf_counter() + sequences = generate_sequences(1000, length=40) + end_time = time.perf_counter() + for s in sequences[:10]: + print(s) + print(f"{len(sequences) / (end_time - start_time):.02f} samples per second") + + print(extract_results(sequences[:10])) diff --git a/ffutils.py b/ffutils.py new file mode 100755 index 0000000..23952e5 --- /dev/null +++ b/ffutils.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. 
+# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import torch +import sys, contextlib + +import torch +from torch import Tensor + +###################################################################### + + +@contextlib.contextmanager +def evaluation(*models): + with torch.inference_mode(): + t = [(m, m.training) for m in models] + for m in models: + m.train(False) + yield + for m, u in t: + m.train(u) + + +###################################################################### + +from torch.utils._python_dispatch import TorchDispatchMode + + +def hasNaN(x): + if torch.is_tensor(x): + return x.numel() > 0 and x.isnan().max() + else: + try: + return any([hasNaN(y) for y in x]) + except TypeError: + return False + + +class NaNDetect(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args, kwargs=None): + kwargs = kwargs or {} + res = func(*args, **kwargs) + + if hasNaN(res): + raise RuntimeError( + f"Function {func}(*{args}, **{kwargs}) " "returned a NaN" + ) + return res + + +###################################################################### + + +def exception_hook(exc_type, exc_value, tb): + r"""Hacks the call stack message to show all the local variables + in case of relevant error, and prints tensors as shape, dtype and + device. 
+ + """ + + repr_orig = Tensor.__repr__ + Tensor.__repr__ = lambda x: f"{x.size()}:{x.dtype}:{x.device}" + + while tb: + print("--------------------------------------------------\n") + filename = tb.tb_frame.f_code.co_filename + name = tb.tb_frame.f_code.co_name + line_no = tb.tb_lineno + print(f' File "{filename}", line {line_no}, in {name}') + print(open(filename, "r").readlines()[line_no - 1]) + + if exc_type in {RuntimeError, ValueError, IndexError, TypeError}: + for n, v in tb.tb_frame.f_locals.items(): + print(f" {n} -> {v}") + + print() + tb = tb.tb_next + + Tensor.__repr__ = repr_orig + + print(f"{exc_type.__name__}: {exc_value}") + + +def activate_tensorstack(): + sys.excepthook = exception_hook + + +###################################################################### + +if __name__ == "__main__": + import torch + + def dummy(a, b): + print(a @ b) + + def blah(a, b): + c = b + b + dummy(a, c) + + mmm = torch.randn(2, 3) + xxx = torch.randn(3) + # print(xxx@mmm) + blah(mmm, xxx) + blah(xxx, mmm) diff --git a/graph.py b/graph.py new file mode 100755 index 0000000..07e376a --- /dev/null +++ b/graph.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python + +import math + +import torch, torchvision + +from torch import nn +from torch.nn import functional as F + +import cairo + + +###################################################################### + + +def save_attention_image( + # image to save + filename, + tokens_input, + tokens_output, + # list of 2d tensors T2xT1, T3xT2, ..., TkxTk-1 + attention_matrices, + # do not draw links with a lesser attention + min_link_attention=0, + # draw only the strongest links necessary so that their summed + # attention is above min_total_attention + min_total_attention=None, + # draw only the top k links + k_top=None, + # the purely graphical settings + curved=True, + pixel_scale=8, + token_gap=15, + layer_gap=25, + y_eps=0.5, + padding=10, +): + if k_top is not None: + am = [] + for m in attention_matrices: + am.append(m * 
(m.sort(dim=-1, descending=True).indices < k_top)) + attention_matrices = am + + if min_total_attention is not None: + am = [] + for m in attention_matrices: + s = m.sort(dim=-1) + m = 1 - (s.values.cumsum(-1) < 1 - min_total_attention).long() + b = m.new(m.size()).scatter_(dim=-1, index=s.indices, src=m) + am.append(m * b) + + surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None) + + ctx = cairo.Context(surface) + ctx.scale(pixel_scale, pixel_scale) + + ctx.set_source_rgb(0.0, 0.0, 0.0) + ctx.set_font_size(4.0) + # ctx.select_font_face("Arial", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL) + + x, y = 0, 0 + + ctx.set_line_width(0.25) + for d in range(len(attention_matrices)): + at = attention_matrices[d].to("cpu") + ni = torch.arange(at.size(0))[:, None].expand_as(at) + nj = torch.arange(at.size(1))[None, :].expand_as(at) + at = at.flatten() + o = at.sort().indices + at = at[o] + ni = ni.flatten()[o] + nj = nj.flatten()[o] + for i, j, a in zip(ni, nj, at): + if a > 0 and a >= min_link_attention: + c = 1 - a.item() + ctx.set_source_rgb(c, c, c) + ax, ay = j * token_gap, y - y_eps + ctx.move_to(ax, ay) + dx, dy = i * token_gap, y - layer_gap + y_eps + if curved: + bx, by = ax, ay - layer_gap * 0.5 + cx, cy = dx, dy + layer_gap * 0.5 + ctx.curve_to(bx, by, cx, cy, dx, dy) + else: + ctx.line_to(dx, dy) + ctx.stroke() + y -= layer_gap + + for d in range(0, len(attention_matrices) + 1): + n = ( + attention_matrices[0].size(-1) + if d == 0 + else attention_matrices[d - 1].size(-2) + ) + for n in range(n): + xc, yc = n * token_gap, -d * layer_gap + ctx.set_source_rgb(1.0, 1.0, 1.0) + ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi) + ctx.fill() + ctx.set_source_rgb(0.0, 0.0, 0.0) + ctx.arc(xc, yc, token_gap / 20, 0, 2 * math.pi) + ctx.fill() + + ctx.set_source_rgb(0.0, 0.0, 0.0) + + for k, t in enumerate(tokens_input): + s = str(t) + ( + x_bearing, + y_bearing, + width_t, + height_t, + x_advance, + y_advance, + ) = ctx.text_extents(s) + ctx.move_to(k * 
token_gap - width_t / 2, 2 * token_gap / 5) + ctx.show_text(s) + + for k, t in enumerate(tokens_output): + s = str(t) + ( + x_bearing, + y_bearing, + width_t, + height_t, + x_advance, + y_advance, + ) = ctx.text_extents(s) + ctx.move_to( + k * token_gap - width_t / 2, + -token_gap / 5 - len(attention_matrices) * layer_gap, + ) + ctx.show_text(s) + + x, y, width, height = surface.ink_extents() + x -= padding + y -= padding + width += 2 * padding + height += 2 * padding + pdf_surface = cairo.PDFSurface(filename, width, height) + ctx_pdf = cairo.Context(pdf_surface) + ctx_pdf.set_source_surface(surface, -x, -y) + ctx_pdf.paint() + pdf_surface.finish() + + +###################################################################### + +if __name__ == "__main__": + import mygpt + + tokens_output = ["", "-", 3, 4, ""] + tokens_input = [""] + tokens_output[:-1] + + vocabulary_size = 3 + x = torch.randint(vocabulary_size, (1, len(tokens_input))) + + model = mygpt.MyGPT( + vocabulary_size=vocabulary_size, + dim_model=4, + dim_keys=2, + dim_hidden=2, + nb_heads=2, + nb_blocks=5, + dropout=0.1, + causal=True, + ) + + model.eval() + model.record_attention() + + y1 = model(mygpt.BracketedSequence(x)).x + + attention_matrices = [m[0, 0] for m in model.retrieve_attention()] + + # attention_matrices = [torch.rand(*s) for s in [ (4,5),(3,4),(8,3),(5,8) ]] + + save_attention_image( + "attention.pdf", + tokens_input, + tokens_output, + attention_matrices, + # k_top=2, + min_total_attention=0.9, + ) diff --git a/grid.py b/grid.py new file mode 100755 index 0000000..268f4ee --- /dev/null +++ b/grid.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. 
+# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import math +import torch, torchvision +import torch.nn.functional as F + +name_shapes = ["A", "B", "C", "D", "E", "F"] + +name_colors = ["red", "yellow", "blue", "green", "white", "purple"] + +###################################################################### + + +class GridFactory: + def __init__( + self, + size=6, + max_nb_items=4, + max_nb_transformations=3, + nb_questions=4, + ): + assert size % 2 == 0 + self.size = size + self.max_nb_items = max_nb_items + self.max_nb_transformations = max_nb_transformations + self.nb_questions = nb_questions + + def generate_scene(self): + nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2 + col = torch.full((self.size * self.size,), -1) + shp = torch.full((self.size * self.size,), -1) + a = torch.randperm(len(name_colors) * len(name_shapes))[:nb_items] + col[:nb_items] = a % len(name_colors) + shp[:nb_items] = a // len(name_colors) + i = torch.randperm(self.size * self.size) + col = col[i] + shp = shp[i] + return col.reshape(self.size, self.size), shp.reshape(self.size, self.size) + + def random_transformations(self, scene): + col, shp = scene + + descriptions = [] + nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item() + transformations = torch.randint(5, (nb_transformations,)) + + for t in transformations: + if t == 0: + col, shp = col.flip(0), shp.flip(0) + descriptions += [" vertical flip"] + elif t == 1: + col, shp = col.flip(1), shp.flip(1) + descriptions += [" horizontal flip"] + elif t == 2: + col, shp = col.flip(0).t(), shp.flip(0).t() + descriptions += [" rotate 90 degrees"] + elif t == 3: + col, shp = col.flip(0).flip(1), shp.flip(0).flip(1) + descriptions += [" rotate 180 degrees"] + elif t == 4: + col, shp = col.flip(1).t(), shp.flip(1).t() + descriptions += [" rotate 270 degrees"] + + col, shp = col.contiguous(), shp.contiguous() + + return (col, shp), descriptions + + def 
print_scene(self, scene): + col, shp = scene + + # for i in range(self.size): + # for j in range(self.size): + # if col[i,j] >= 0: + # print(f"at ({i},{j}) {name_colors[col[i,j]]} {name_shapes[shp[i,j]]}") + + for i in range(self.size): + for j in range(self.size): + if col[i, j] >= 0: + print(f"{name_colors[col[i,j]][0]}{name_shapes[shp[i,j]]}", end="") + elif j == 0: + print(" +", end="") + else: + print("-+", end="") + if j < self.size - 1: + print("--", end="") + else: + print("") + if i < self.size - 1: + for j in range(self.size - 1): + print(" | ", end="") + print(" |") + + def grid_positions(self, scene): + col, shp = scene + + properties = [] + + for i in range(self.size): + for j in range(self.size): + if col[i, j] >= 0: + n = f"{name_colors[col[i,j]]} {name_shapes[shp[i,j]]}" + properties += [f"a {n} at {i} {j}"] + + return properties + + def all_properties(self, scene): + col, shp = scene + + properties = [] + + for i1 in range(self.size): + for j1 in range(self.size): + if col[i1, j1] >= 0: + n1 = f"{name_colors[col[i1,j1]]} {name_shapes[shp[i1,j1]]}" + properties += [f"there is a {n1}"] + if i1 < self.size // 2: + properties += [f"a {n1} is in the top half"] + if i1 >= self.size // 2: + properties += [f"a {n1} is in the bottom half"] + if j1 < self.size // 2: + properties += [f"a {n1} is in the left half"] + if j1 >= self.size // 2: + properties += [f"a {n1} is in the right half"] + for i2 in range(self.size): + for j2 in range(self.size): + if col[i2, j2] >= 0: + n2 = f"{name_colors[col[i2,j2]]} {name_shapes[shp[i2,j2]]}" + if i1 > i2: + properties += [f"a {n1} is below a {n2}"] + if i1 < i2: + properties += [f"a {n1} is above a {n2}"] + if j1 > j2: + properties += [f"a {n1} is right of a {n2}"] + if j1 < j2: + properties += [f"a {n1} is left of a {n2}"] + if abs(i1 - i2) + abs(j1 - j2) == 1: + properties += [f"a {n1} is next to a {n2}"] + + return properties + + def generate_scene_and_questions(self): + while True: + while True: + start_scene = 
self.generate_scene() + scene, transformations = self.random_transformations(start_scene) + true = self.all_properties(scene) + if len(true) >= self.nb_questions: + break + + for a in range(10): + col, shp = scene + col, shp = col.view(-1), shp.view(-1) + p = torch.randperm(col.size(0)) + col, shp = col[p], shp[p] + other_scene = ( + col.view(self.size, self.size), + shp.view(self.size, self.size), + ) + + false = self.all_properties(other_scene) + + # We sometime add properties from a totally different + # scene to have negative "there is a xxx xxx" + # properties + if torch.rand(1).item() < 0.2: + other_scene = self.generate_scene() + false += self.all_properties(other_scene) + + false = list(set(false) - set(true)) + if len(false) >= self.nb_questions: + break + + if a < 10: + break + + true = [true[k] for k in torch.randperm(len(true))[: self.nb_questions]] + false = [false[k] for k in torch.randperm(len(false))[: self.nb_questions]] + true = [" " + q + " true" for q in true] + false = [" " + q + " false" for q in false] + + union = true + false + questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]] + + result = " ".join( + [" " + x for x in self.grid_positions(start_scene)] + + transformations + + questions + ) + + return start_scene, scene, result + + def generate_samples(self, nb, progress_bar=None): + result = [] + + r = range(nb) + if progress_bar is not None: + r = progress_bar(r) + + for _ in r: + result.append(self.generate_scene_and_questions()[2]) + + return result + + +###################################################################### + +if __name__ == "__main__": + import time + + grid_factory = GridFactory() + + # start_time = time.perf_counter() + # samples = grid_factory.generate_samples(10000) + # end_time = time.perf_counter() + # print(f"{len(samples) / (end_time - start_time):.02f} samples per second") + + start_scene, scene, questions = grid_factory.generate_scene_and_questions() + print() + print("-- Original 
scene -----------------------------") + print() + grid_factory.print_scene(start_scene) + print() + print("-- Transformed scene --------------------------") + print() + grid_factory.print_scene(scene) + print() + print("-- Sequence -----------------------------------") + print() + print(questions) + +###################################################################### diff --git a/main.py b/main.py new file mode 100755 index 0000000..df46652 --- /dev/null +++ b/main.py @@ -0,0 +1,912 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import math, sys, argparse, time, tqdm, os, datetime, warnings + +import torch, torchvision +from torch import nn +from torch.nn import functional as F + +import ffutils +import mygpt, tasks, problems + +###################################################################### + +if torch.cuda.is_available(): + device = torch.device("cuda") + torch.backends.cuda.matmul.allow_tf32 = True +else: + device = torch.device("cpu") + +###################################################################### + +parser = argparse.ArgumentParser( + description="An implementation of GPT with cache.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, +) + +parser.add_argument( + "--task", + type=str, + default="twotargets", + help="byheart, learnop, guessop, mixing, memory, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp", +) + +parser.add_argument("--log_filename", type=str, default="train.log", help=" ") + +parser.add_argument("--result_dir", type=str, default=None) + +parser.add_argument("--seed", type=int, default=0) + +parser.add_argument("--max_percents_of_test_in_train", type=int, default=1) + +######################################## + +parser.add_argument("--nb_epochs", type=int, default=50) + +parser.add_argument("--batch_size", type=int, default=None) + 
+parser.add_argument("--nb_train_samples", type=int, default=None) + +parser.add_argument("--nb_test_samples", type=int, default=None) + +parser.add_argument("--optim", type=str, default="adam") + +######################################## + +parser.add_argument("--nb_warmup_iter", type=int, default=100) + +parser.add_argument("--nb_decay_iter", type=int, default=5000) + +parser.add_argument("--learning_rate", type=float, default=6e-4) + +parser.add_argument("--min_learning_rate", type=float, default=6e-5) + +######################################## + +parser.add_argument("--model", type=str, default=None) + +parser.add_argument("--attention", type=str, default=None) + +parser.add_argument("--dim_model", type=int, default=None) + +parser.add_argument("--dim_keys", type=int, default=None) + +parser.add_argument("--dim_hidden", type=int, default=None) + +parser.add_argument("--nb_heads", type=int, default=None) + +parser.add_argument("--nb_lines", type=int, default=None) + +parser.add_argument("--caterpillar_height", type=int, default=None) + +parser.add_argument("--rho", type=float, default=0.0) + +parser.add_argument("--dim_rec_v", type=int, default=None) + +parser.add_argument("--nb_blocks", type=int, default=None) + +parser.add_argument("--dropout", type=float, default=0.1) + +######################################## + +parser.add_argument("--deterministic_synthesis", action="store_true", default=False) + +parser.add_argument("--no_checkpoint", action="store_true", default=False) + +parser.add_argument("--overwrite_results", action="store_true", default=False) + +parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth") + +############################## +# rpl options + +parser.add_argument("--rpl_nb_starting_values", type=int, default=3) + +parser.add_argument("--rpl_max_input", type=int, default=9) + +parser.add_argument("--rpl_prog_len", type=int, default=8) + +parser.add_argument("--rpl_nb_runs", type=int, default=5) + 
+parser.add_argument("--rpl_no_prog", action="store_true", default=False) + +############################## +# grid options + +parser.add_argument("--grid_size", type=int, default=6) + +############################## +# picoclvr options + +parser.add_argument("--picoclvr_nb_colors", type=int, default=5) + +parser.add_argument("--picoclvr_height", type=int, default=12) + +parser.add_argument("--picoclvr_width", type=int, default=16) + +parser.add_argument("--picocvlr_prune_properties", type=str, default="none") + +############################## +# Maze options + +parser.add_argument("--maze_height", type=int, default=13) + +parser.add_argument("--maze_width", type=int, default=21) + +parser.add_argument("--maze_nb_walls", type=int, default=15) + +############################## +# Snake options + +parser.add_argument("--snake_height", type=int, default=9) + +parser.add_argument("--snake_width", type=int, default=12) + +parser.add_argument("--snake_nb_colors", type=int, default=5) + +parser.add_argument("--snake_length", type=int, default=200) + +############################## +# Stack options + +parser.add_argument("--stack_nb_steps", type=int, default=100) + +parser.add_argument("--stack_nb_stacks", type=int, default=3) + +parser.add_argument("--stack_nb_digits", type=int, default=3) + +parser.add_argument("--stack_fraction_values_for_train", type=float, default=0.75) + +############################## +# Expr options + +parser.add_argument("--expr_nb_variables", type=int, default=5) + +parser.add_argument("--expr_sequence_length", type=int, default=40) + +parser.add_argument("--expr_operand_max", type=int, default=9) + +parser.add_argument("--expr_result_max", type=int, default=99) + +parser.add_argument("--expr_input_file", type=str, default=None) + +############################## +# Memory + +parser.add_argument("--memory_len_total", type=int, default=32) + +############################## +# Mixing + +parser.add_argument("--mixing_hard", action="store_true", 
default=False) + +parser.add_argument("--mixing_deterministic_start", action="store_true", default=False) + +###################################################################### + +args = parser.parse_args() + +assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"} + +if args.result_dir is None: + args.result_dir = f"results_{args.task}_{args.model}" + +###################################################################### + +default_task_args = { + "addition": { + "model": "352M", + "batch_size": 25, + "nb_train_samples": 250000, + "nb_test_samples": 10000, + }, + "byheart": { + "model": "37M", + "batch_size": 25, + "nb_train_samples": 50000, + "nb_test_samples": 10000, + }, + "expr": { + "model": "352M", + "batch_size": 25, + "nb_train_samples": 2500000, + "nb_test_samples": 10000, + }, + "grid": { + "model": "37M", + "batch_size": 25, + "nb_train_samples": 250000, + "nb_test_samples": 10000, + }, + "qmlp": { + "model": "37M", + "batch_size": 10, + "nb_train_samples": 100000, + "nb_test_samples": 1000, + }, + "guessop": { + "model": "352M", + "batch_size": 25, + "nb_train_samples": 1000000, + "nb_test_samples": 10000, + }, + "learnop": { + "model": "37M", + "batch_size": 25, + "nb_train_samples": 50000, + "nb_test_samples": 10000, + }, + "maze": { + "model": "37M", + "batch_size": 5, + "nb_train_samples": 100000, + "nb_test_samples": 10000, + }, + "picoclvr": { + "model": "37M", + "batch_size": 25, + "nb_train_samples": 250000, + "nb_test_samples": 10000, + }, + "rpl": { + "model": "352M", + "batch_size": 5, + "nb_train_samples": 2500000, + "nb_test_samples": 10000, + }, + "snake": { + "model": "37M", + "batch_size": 25, + "nb_train_samples": 250000, + "nb_test_samples": 10000, + }, + "stack": { + "model": "37M", + "batch_size": 25, + "nb_train_samples": 100000, + "nb_test_samples": 1000, + }, + "twotargets": { + "model": "37M", + "batch_size": 25, + "nb_train_samples": 50000, + "nb_test_samples": 10000, + }, + "memory": { + "model": "37M", + 
"batch_size": 25, + "nb_train_samples": 25000, + "nb_test_samples": 10000, + }, + "mixing": { + "model": "37M", + "batch_size": 25, + "nb_train_samples": 250000, + "nb_test_samples": 10000, + }, + "mnist": { + "model": "37M", + "batch_size": 10, + "nb_train_samples": 60000, + "nb_test_samples": 10000, + }, +} + +if args.task in default_task_args: + for k, v in default_task_args[args.task].items(): + if getattr(args, k) is None: + setattr(args, k, v) + +###################################################################### + +default_model_args = { + "17K": { + "attention": "mha", + "dim_model": 32, + "dim_keys": 32, + "dim_hidden": 32, + "nb_heads": 2, + "dim_rec_v": 16, + "nb_blocks": 2, + }, + "17K-C": { + "attention": "caterpillar", + "dim_model": 32, + "dim_keys": 32, + "dim_hidden": 32, + "nb_heads": 2, + "nb_lines": 16, + "caterpillar_height": 4, + "dim_rec_v": 16, + "nb_blocks": 2, + }, + "4M": { + "attention": "mha", + "dim_model": 256, + "dim_keys": 32, + "dim_hidden": 1024, + "nb_heads": 4, + "dim_rec_v": 64, + "nb_blocks": 6, + }, + "4M-C": { + "attention": "caterpillar", + "dim_model": 256, + "dim_keys": 32, + "dim_hidden": 1024, + "nb_heads": 4, + "nb_lines": 32, + "caterpillar_height": 4, + "dim_rec_v": 64, # dim_model / nb_heads + "nb_blocks": 6, + }, + "37M": { + "dim_model": 512, + "dim_keys": 64, + "dim_hidden": 2048, + "nb_heads": 8, + "dim_rec_v": 64, + "nb_blocks": 12, + }, + "37M-C": { + "attention": "caterpillar", + "dim_model": 512, + "dim_keys": 64, + "dim_hidden": 2048, + "nb_heads": 8, + "nb_lines": 256, + "caterpillar_height": 32, + "dim_rec_v": 64, + "nb_blocks": 12, + }, + "122M": { + "attention": "mha", + "dim_model": 768, + "dim_keys": 64, + "dim_hidden": 2048, + "nb_heads": 8, + "dim_rec_v": 96, + "nb_blocks": 24, + }, + "122M-C": { + "attention": "caterpillar", + "dim_model": 768, + "dim_keys": 64, + "dim_hidden": 2048, + "nb_heads": 8, + "nb_lines": 128, + "dim_rec_v": 96, + "nb_blocks": 24, + }, + "352M": { + "attention": "mha", 
+ "dim_model": 1024, + "dim_keys": 64, + "dim_hidden": 2048, + "nb_heads": 8, + "dim_rec_v": 128, + "nb_blocks": 48, + }, + "352M-C": { + "attention": "caterpillar", + "dim_model": 1024, + "dim_keys": 64, + "dim_hidden": 2048, + "nb_heads": 8, + "nb_lines": 128, + "dim_rec_v": 128, + "nb_blocks": 48, + }, +} + +if args.model in default_model_args: + for k, v in default_model_args[args.model].items(): + if getattr(args, k) is None: + setattr(args, k, v) +else: + raise ValueError(f"Unknown model {args.model}") + +###################################################################### + +try: + os.mkdir(args.result_dir) +except FileExistsError: + if not args.overwrite_results: + print(f"result directory {args.result_dir} already exists") + exit(1) + +log_file = open(os.path.join(args.result_dir, args.log_filename), "a") + +if args.seed >= 0: + # torch.backends.cudnn.deterministic = True + # torch.backends.cudnn.benchmark = False + # torch.use_deterministic_algorithms(True) + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + +###################################################################### + + +def log_string(s): + t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime()) + + if log_file is not None: + log_file.write(t + s + "\n") + log_file.flush() + + print(t + s) + sys.stdout.flush() + + +with os.popen("sha256sum *.py") as f: + for l in f: + log_string(f"sha256sum {l.strip()}") + +now = time.strftime("%Y%m%d-%H%M%S", time.localtime()) +os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py *.sh") + +log_string(f"argv {' '.join(sys.argv)}") + +for n in vars(args): + log_string(f"args.{n} {getattr(args, n)}") + + +###################################################################### + +# from nanoGPT + + +def get_lr(it): + # 1) linear warmup for warmup_iter steps + if it < args.nb_warmup_iter: + return args.learning_rate * it / args.nb_warmup_iter + # 2) if it > nb_decay_iter, return min learning rate + if it 
> args.nb_decay_iter: + return args.min_learning_rate + # 3) in between, use cosine decay down to min learning rate + decay_ratio = (it - args.nb_warmup_iter) / ( + args.nb_decay_iter - args.nb_warmup_iter + ) + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + return args.min_learning_rate + coeff * ( + args.learning_rate - args.min_learning_rate + ) + + +###################################################################### + + +def picoclvr_pruner_horizontal_green(p): + return not ("green" in p and ("left" in p or "right" in p)) + + +picoclvr_pruner_train = ( + picoclvr_pruner_horizontal_green + if args.picocvlr_prune_properties in {"train+eval"} + else None +) + +picoclvr_pruner_eval = ( + (lambda p: not picoclvr_pruner_horizontal_green(p)) + if args.picocvlr_prune_properties in {"train+eval", "eval"} + else None +) + +###################################################################### + +device_data = device + +if args.task == "byheart": + task = tasks.SandBox( + problem=problems.ProblemByHeart(), + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + logger=log_string, + device=device_data, + ) + args.max_percents_of_test_in_train = -1 + +elif args.task == "learnop": + task = tasks.SandBox( + problem=problems.ProblemLearnOperator(), + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + logger=log_string, + device=device_data, + ) + + +elif args.task == "guessop": + task = tasks.SandBox( + problem=problems.ProblemGuessOperator(), + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + logger=log_string, + device=device_data, + ) + + +elif args.task == "twotargets": + task = tasks.SandBox( + problem=problems.ProblemTwoTargets(), + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + logger=log_string, 
+ device=device_data, + ) + +elif args.task == "memory": + task = tasks.SandBox( + problem=problems.ProblemMemory(len_total=args.memory_len_total), + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + logger=log_string, + device=device_data, + ) + +elif args.task == "mixing": + task = tasks.SandBox( + problem=problems.ProblemMixing( + hard=args.mixing_hard, random_start=not args.mixing_deterministic_start + ), + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + logger=log_string, + device=device_data, + ) + +elif args.task == "addition": + task = tasks.SandBox( + problem=problems.ProblemAddition(), + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + logger=log_string, + device=device_data, + ) + +elif args.task == "picoclvr": + task = tasks.PicoCLVR( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + height=args.picoclvr_height, + width=args.picoclvr_width, + nb_colors=args.picoclvr_nb_colors, + logger=log_string, + device=device_data, + pruner_train=picoclvr_pruner_train, + pruner_eval=picoclvr_pruner_eval, + ) + +elif args.task == "mnist": + task = tasks.MNIST( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + device=device_data, + ) + +elif args.task == "maze": + task = tasks.Maze( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + height=args.maze_height, + width=args.maze_width, + nb_walls=args.maze_nb_walls, + device=device_data, + ) + +elif args.task == "snake": + task = tasks.Snake( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + height=args.snake_height, + width=args.snake_width, + nb_colors=args.snake_nb_colors, + 
length=args.snake_length, + prompt_length=args.snake_length // 2, + device=device_data, + ) + +elif args.task == "stack": + task = tasks.Stack( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + logger=log_string, + nb_steps=args.stack_nb_steps, + nb_stacks=args.stack_nb_stacks, + nb_digits=args.stack_nb_digits, + fraction_values_for_train=args.stack_fraction_values_for_train, + device=device_data, + ) + +elif args.task == "expr": + task = tasks.Expr( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + nb_variables=args.expr_nb_variables, + sequence_length=args.expr_sequence_length, + operand_max=args.expr_operand_max, + result_max=args.expr_result_max, + batch_size=args.batch_size, + device=device_data, + ) + +elif args.task == "rpl": + task = tasks.RPL( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + nb_starting_values=args.rpl_nb_starting_values, + max_input=args.rpl_max_input, + prog_len=args.rpl_prog_len, + nb_runs=args.rpl_nb_runs, + no_prog=args.rpl_no_prog, + logger=log_string, + device=device_data, + ) + +elif args.task == "grid": + task = tasks.Grid( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + size=args.grid_size, + logger=log_string, + device=device_data, + ) + +elif args.task == "qmlp": + task = tasks.QMLP( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + result_dir=args.result_dir, + logger=log_string, + device=device_data, + ) + +else: + raise ValueError(f"Unknown task {args.task}") + +###################################################################### + +log_string(f"device {device}") + +vocabulary_size = task.vocabulary_size() + +log_string(f"vocabulary_size {vocabulary_size}") + +############################## + +model = mygpt.MyGPT( + 
vocabulary_size=vocabulary_size, + dim_model=args.dim_model, + dim_keys=args.dim_keys, + dim_hidden=args.dim_hidden, + nb_heads=args.nb_heads, + nb_lines=args.nb_lines, + caterpillar_height=args.caterpillar_height, + dim_rec_v=args.dim_rec_v, + nb_blocks=args.nb_blocks, + causal=True, + dropout=args.dropout, + attention_layer=args.attention, +) + +model.to(device) + +nb_parameters = sum(p.numel() for p in model.parameters()) +log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)") + +###################################################################### + +nb_epochs_finished = 0 + +if args.no_checkpoint: + log_string(f"not trying to load checkpoint.") + +else: + try: + checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name) + checkpoint = torch.load(checkpoint_name) + nb_epochs_finished = checkpoint["nb_epochs_finished"] + model.load_state_dict(checkpoint["model_state"]) + torch.set_rng_state(checkpoint["rng_state"]) + if torch.cuda.is_available(): + torch.cuda.set_rng_state(checkpoint["cuda_rng_state"]) + + log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.") + + except FileNotFoundError: + log_string("starting from scratch.") + + except: + log_string("error when loading the checkpoint.") + exit(1) + +###################################################################### + +if args.task == "expr" and args.expr_input_file is not None: + task.produce_results( + n_epoch=nb_epochs_finished, + model=model, + result_dir=args.result_dir, + logger=log_string, + deterministic_synthesis=args.deterministic_synthesis, + input_file=args.expr_input_file, + ) + + exit(0) + +###################################################################### + +nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default + +# Compute the entropy of the training tokens + +token_count = 0 +for input in task.batches(split="train"): + token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1)) +token_probas = 
token_count / token_count.sum() +entropy = -torch.xlogy(token_probas, token_probas).sum() +train_set_perplexity = math.exp(entropy) + +###################################################################### +# A bit of paranoia never hurts + +if args.max_percents_of_test_in_train >= 0: + + def subsets_as_tuples(batches, cs): + s = set() + for batch in batches: + for x in batch: + s.add(tuple([v.item() for v in x])) + if len(s) == cs: + yield s + s = set() + yield s + + nb_test, nb_in_train = 0, 0 + for test_subset in subsets_as_tuples(task.batches(split="test"), 25000): + in_train = set() + for train_subset in subsets_as_tuples(task.batches(split="train"), 25000): + in_train.update(test_subset.intersection(train_subset)) + nb_in_train += len(in_train) + nb_test += len(test_subset) + + log_string( + f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set" + ) + + assert ( + nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100 + ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set" + +############################## + +nb_samples_seen = 0 + +if nb_epochs_finished >= nb_epochs: + task.produce_results( + n_epoch=nb_epochs_finished, + model=model, + result_dir=args.result_dir, + logger=log_string, + deterministic_synthesis=args.deterministic_synthesis, + ) + +time_pred_result = None + +it = 0 + +for n_epoch in range(nb_epochs_finished, nb_epochs): + if args.optim == "sgd": + optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate) + elif args.optim == "adam": + optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) + elif args.optim == "adamw": + optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate) + else: + raise ValueError(f"Unknown optimizer {args.optim}.") + + model.train() + + nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0 + + for input in task.batches(split="train"): + model.reset_inner_loss() 
+ input = input.to(device) + + output = model(mygpt.BracketedSequence(input)).x + loss = F.cross_entropy(output.transpose(1, 2), input) + inner_loss = model.get_inner_loss() + + acc_train_loss += loss.item() * input.size(0) + acc_train_inner_loss += inner_loss.item() * input.size(0) + + nb_train_samples += input.size(0) + nb_samples_seen += input.size(0) + + total_loss = loss + (args.rho * inner_loss if args.rho > 0 else 0.0) + + it += 1 + lr = get_lr(it) + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + # log_string(f"learning_rate {lr}") + + optimizer.zero_grad() + total_loss.backward() + optimizer.step() + + with torch.autograd.no_grad(): + model.eval() + + nb_test_samples, acc_test_loss = 0, 0.0 + + for input in task.batches(split="test"): + input = input.to(device) + + output = model(mygpt.BracketedSequence(input)).x + loss = F.cross_entropy(output.transpose(1, 2), input) + acc_test_loss += loss.item() * input.size(0) + nb_test_samples += input.size(0) + + log_string( + f"loss {n_epoch} train_loss {acc_train_loss/nb_train_samples} train_inner_loss {acc_train_inner_loss/nb_train_samples} test_prediction {acc_test_loss/nb_test_samples}" + ) + + task.produce_results( + n_epoch=n_epoch, + model=model, + result_dir=args.result_dir, + logger=log_string, + deterministic_synthesis=args.deterministic_synthesis, + ) + + train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) + test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples)) + + log_string( + f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}" + ) + + time_current_result = datetime.datetime.now() + if time_pred_result is not None: + log_string( + f"next_result {time_current_result + (time_current_result - time_pred_result)}" + ) + time_pred_result = time_current_result + + checkpoint = { + "nb_epochs_finished": n_epoch + 1, + "model_state": model.state_dict(), + "rng_state": 
torch.get_rng_state(), + } + + if torch.cuda.is_available(): + checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state() + + checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name) + torch.save(checkpoint, checkpoint_name) + log_string(f"saved checkpoint {checkpoint_name}") + +###################################################################### diff --git a/maze.py b/maze.py new file mode 100755 index 0000000..8ac9fce --- /dev/null +++ b/maze.py @@ -0,0 +1,309 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import torch, torchvision + +###################################################################### + +v_empty, v_wall, v_start, v_goal, v_path = 0, 1, 2, 3, 4 + + +def create_maze(h=11, w=17, nb_walls=8): + assert h % 2 == 1 and w % 2 == 1 + + a, k = 0, 0 + + while k < nb_walls: + while True: + if a == 0: + m = torch.zeros(h, w, dtype=torch.int64) + m[0, :] = 1 + m[-1, :] = 1 + m[:, 0] = 1 + m[:, -1] = 1 + + r = torch.rand(4) + + if r[0] <= 0.5: + i1, i2, j = ( + int((r[1] * h).item()), + int((r[2] * h).item()), + int((r[3] * w).item()), + ) + i1, i2, j = i1 - i1 % 2, i2 - i2 % 2, j - j % 2 + i1, i2 = min(i1, i2), max(i1, i2) + if i2 - i1 > 1 and i2 - i1 <= h / 2 and m[i1 : i2 + 1, j].sum() <= 1: + m[i1 : i2 + 1, j] = 1 + break + else: + i, j1, j2 = ( + int((r[1] * h).item()), + int((r[2] * w).item()), + int((r[3] * w).item()), + ) + i, j1, j2 = i - i % 2, j1 - j1 % 2, j2 - j2 % 2 + j1, j2 = min(j1, j2), max(j1, j2) + if j2 - j1 > 1 and j2 - j1 <= w / 2 and m[i, j1 : j2 + 1].sum() <= 1: + m[i, j1 : j2 + 1] = 1 + break + a += 1 + + if a > 10 * nb_walls: + a, k = 0, 0 + + k += 1 + + return m + + +###################################################################### + + +def compute_distance(walls, goal_i, goal_j): + max_length = walls.numel() + dist = torch.full_like(walls, max_length) + + dist[goal_i, goal_j] = 0 + pred_dist 
= torch.empty_like(dist) + + while True: + pred_dist.copy_(dist) + d = ( + torch.cat( + ( + dist[None, 1:-1, 0:-2], + dist[None, 2:, 1:-1], + dist[None, 1:-1, 2:], + dist[None, 0:-2, 1:-1], + ), + 0, + ).min(dim=0)[0] + + 1 + ) + + dist[1:-1, 1:-1] = torch.min(dist[1:-1, 1:-1], d) + dist = walls * max_length + (1 - walls) * dist + + if dist.equal(pred_dist): + return dist * (1 - walls) + + +###################################################################### + + +def compute_policy(walls, goal_i, goal_j): + distance = compute_distance(walls, goal_i, goal_j) + distance = distance + walls.numel() * walls + + value = distance.new_full((4,) + distance.size(), walls.numel()) + value[0, :, 1:] = distance[:, :-1] # < + value[1, :, :-1] = distance[:, 1:] # > + value[2, 1:, :] = distance[:-1, :] # ^ + value[3, :-1, :] = distance[1:, :] # v + + proba = (value.min(dim=0)[0][None] == value).float() + proba = proba / proba.sum(dim=0)[None] + proba = proba * (1 - walls) + walls.float() / 4 + + return proba + + +def stationary_densities(mazes, policies): + policies = policies * (mazes != v_goal)[:, None] + start = (mazes == v_start).nonzero(as_tuple=True) + probas = mazes.new_zeros(mazes.size(), dtype=torch.float32) + pred_probas = probas.clone() + probas[start] = 1.0 + + while not pred_probas.equal(probas): + pred_probas.copy_(probas) + probas.zero_() + probas[:, 1:, :] += pred_probas[:, :-1, :] * policies[:, 3, :-1, :] + probas[:, :-1, :] += pred_probas[:, 1:, :] * policies[:, 2, 1:, :] + probas[:, :, 1:] += pred_probas[:, :, :-1] * policies[:, 1, :, :-1] + probas[:, :, :-1] += pred_probas[:, :, 1:] * policies[:, 0, :, 1:] + probas[start] = 1.0 + + return probas + + +###################################################################### + + +def mark_path(walls, i, j, goal_i, goal_j, policy): + action = torch.distributions.categorical.Categorical( + policy.permute(1, 2, 0) + ).sample() + n, nmax = 0, walls.numel() + while i != goal_i or j != goal_j: + di, dj = [(0, -1), (0, 
1), (-1, 0), (1, 0)][action[i, j]] + i, j = i + di, j + dj + assert walls[i, j] == 0 + walls[i, j] = v_path + n += 1 + assert n < nmax + + +def path_optimality(ref_paths, paths): + return (ref_paths == v_path).long().flatten(1).sum(1) == ( + paths == v_path + ).long().flatten(1).sum(1) + + +def path_correctness(mazes, paths): + still_ok = (mazes - (paths * (paths != v_path))).view(mazes.size(0), -1).abs().sum( + 1 + ) == 0 + reached = still_ok.new_zeros(still_ok.size()) + current, pred_current = paths.clone(), paths.new_zeros(paths.size()) + goal = (mazes == v_goal).long() + while not pred_current.equal(current): + pred_current.copy_(current) + u = (current == v_start).long() + possible_next = ( + u[:, 2:, 1:-1] + u[:, 0:-2, 1:-1] + u[:, 1:-1, 2:] + u[:, 1:-1, 0:-2] > 0 + ).long() + u = u[:, 1:-1, 1:-1] + reached += ((goal[:, 1:-1, 1:-1] * possible_next).sum((1, 2)) == 1) * ( + (current == v_path).sum((1, 2)) == 0 + ) + current[:, 1:-1, 1:-1] = (1 - u) * current[:, 1:-1, 1:-1] + ( + v_start - v_path + ) * (possible_next * (current[:, 1:-1, 1:-1] == v_path)) + still_ok *= (current == v_start).sum((1, 2)) <= 1 + + return still_ok * reached + + +###################################################################### + + +def create_maze_data( + nb, height=11, width=17, nb_walls=8, dist_min=10, progress_bar=lambda x: x +): + mazes = torch.empty(nb, height, width, dtype=torch.int64) + paths = torch.empty(nb, height, width, dtype=torch.int64) + policies = torch.empty(nb, 4, height, width) + + for n in progress_bar(range(nb)): + maze = create_maze(height, width, nb_walls) + i = (maze == v_empty).nonzero() + while True: + start, goal = i[torch.randperm(i.size(0))[:2]] + if (start - goal).abs().sum() >= dist_min: + break + start_i, start_j, goal_i, goal_j = start[0], start[1], goal[0], goal[1] + + policy = compute_policy(maze, goal_i, goal_j) + path = maze.clone() + mark_path(path, start_i, start_j, goal_i, goal_j, policy) + maze[start_i, start_j] = v_start + maze[goal_i, 
goal_j] = v_goal + path[start_i, start_j] = v_start + path[goal_i, goal_j] = v_goal + + mazes[n] = maze + paths[n] = path + policies[n] = policy + + return mazes, paths, policies + + +###################################################################### + + +def save_image( + name, + mazes, + target_paths=None, + predicted_paths=None, + path_correct=None, + path_optimal=None, +): + colors = torch.tensor( + [ + [255, 255, 255], # empty + [0, 0, 0], # wall + [0, 255, 0], # start + [127, 127, 255], # goal + [255, 0, 0], # path + ] + ) + + mazes = mazes.cpu() + + c_mazes = ( + colors[mazes.reshape(-1)].reshape(mazes.size() + (-1,)).permute(0, 3, 1, 2) + ) + + imgs = c_mazes.unsqueeze(1) + + if target_paths is not None: + target_paths = target_paths.cpu() + + c_target_paths = ( + colors[target_paths.reshape(-1)] + .reshape(target_paths.size() + (-1,)) + .permute(0, 3, 1, 2) + ) + + imgs = torch.cat((imgs, c_target_paths.unsqueeze(1)), 1) + + if predicted_paths is not None: + predicted_paths = predicted_paths.cpu() + c_predicted_paths = ( + colors[predicted_paths.reshape(-1)] + .reshape(predicted_paths.size() + (-1,)) + .permute(0, 3, 1, 2) + ) + imgs = torch.cat((imgs, c_predicted_paths.unsqueeze(1)), 1) + + img = torch.tensor([255, 255, 0]).view(1, -1, 1, 1) + + # NxKxCxHxW + if path_optimal is not None: + path_optimal = path_optimal.cpu().long().view(-1, 1, 1, 1) + img = ( + img * (1 - path_optimal) + + torch.tensor([0, 255, 0]).view(1, -1, 1, 1) * path_optimal + ) + + if path_correct is not None: + path_correct = path_correct.cpu().long().view(-1, 1, 1, 1) + img = img * path_correct + torch.tensor([255, 0, 0]).view(1, -1, 1, 1) * ( + 1 - path_correct + ) + + img = img.expand( + -1, -1, imgs.size(3) + 2, 1 + imgs.size(1) * (1 + imgs.size(4)) + ).clone() + + print(f"{img.size()=} {imgs.size()=}") + + for k in range(imgs.size(1)): + img[ + :, + :, + 1 : 1 + imgs.size(3), + 1 + k * (1 + imgs.size(4)) : 1 + k * (1 + imgs.size(4)) + imgs.size(4), + ] = imgs[:, k] + + img 
= img.float() / 255.0 + + torchvision.utils.save_image(img, name, nrow=4, padding=1, pad_value=224.0 / 256) + + +###################################################################### + +if __name__ == "__main__": + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + mazes, paths, policies = create_maze_data(8) + mazes, paths = mazes.to(device), paths.to(device) + save_image("test.png", mazes=mazes, target_paths=paths, predicted_paths=paths) + print(path_correctness(mazes, paths)) + +###################################################################### diff --git a/memload.py b/memload.py new file mode 100755 index 0000000..5fcd089 --- /dev/null +++ b/memload.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python + +import torch + +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CppExtension + +cpp_source = """ +std::vector greedy_lines_allocation(torch::Tensor load_start, float decay, torch::Tensor line_requests) { + auto nb_lines = load_start.size(1); + auto batch_size = line_requests.size(0); + auto nb_heads = line_requests.size(1); + auto T = line_requests.size(2); + + auto load_start_a = load_start.accessor(); + auto line_requests_a = line_requests.accessor(); + + auto load = torch::empty({batch_size, nb_lines, T}); + auto load_a = load.accessor(); + + auto allocation_result = torch::empty({batch_size,nb_heads,T},torch::TensorOptions().dtype(torch::kInt64)); + auto allocation_result_a = allocation_result.accessor(); + + for(int n = 0; n < batch_size; n++) { + for(int t = 0; t < T; t++) { + for(int l = 0; l < nb_lines; l++) { + if(t == 0) { + load[n][l][t] = decay * load_start_a[n][l]; + } else { + load[n][l][t] = decay * load[n][l][t-1]; + } + } + for(int h = 0; h < nb_heads; h++) { + if(line_requests_a[n][h][t] > 0) { + int l_lowest_load; + for(int l = 0; l < nb_lines; l++) { + if(l == 0 || load_a[n][l][t] + +# This is an implementation from scratch of a "GPT", that is a model +# composed of several causal 
self-attention blocks. It is equipped +# with a caching mechanism for keys and values to avoid a O(N^3) cost +# for auto-regression. + +import math, warnings + +import torch, einops + +from torch import nn +from torch.nn import functional as F +from functorch.dim import dims + +import ffutils + +# import memload + +###################################################################### + +# A BracketedSequence is a BxTx... tensor with a first and a nb time +# steps to compute. + +# Modules able to process it expect that they will have to process a +# first bracket starting at t=0, followed by a succession of brackets +# that move forward in time, do not overlap, and cover the axis T with +# no holes. +# +# Although it is more general, for a classical prompt-conditioned +# auto-regressive process it will be a first bracket starting at 0 and +# of arbitrary length for the "prompt", followed by brackets of length +# 1 for the successive tokens. +# +# Modules able to process brackets may implement a cache that is +# resetted when the input bracket starts at t=0 + + +class BracketedSequence: + def __init__(self, x, first=None, nb=None, init_cache=None): + self.x = x + assert (first is None and nb is None and init_cache is None) or ( + first is not None and nb is not None and init_cache is not None + ) + + self.first = 0 if first is None else first + self.nb = x.size(1) if nb is None else nb + self.init_cache = True if init_cache is None else init_cache + + def slice(self): + return self.x[:, self.first : self.first + self.nb] + + def complete(self): + return self.first == 0 and self.nb == self.x.size(1) + + +###################################################################### + + +class CacheWrapper(nn.Module): + def __init__(self, *f): + super().__init__() + self.f = f[0] if len(f) == 1 else nn.Sequential(*f) + + def forward(self, bs): + if bs.init_cache: + y = self.f(bs.slice()) + self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:])) + self.cache_y[:, 
bs.first : bs.first + bs.nb] = y + else: + assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2]) + assert bs.first + bs.nb <= self.cache_y.size(1) + self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice()) + + return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache) + + +############################## + + +class WithResidual(nn.Module): + def __init__(self, *f): + super().__init__() + self.f = f[0] if len(f) == 1 else nn.Sequential(*f) + + def forward(self, bs): + return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb, bs.init_cache) + + +############################## + + +class AddPositionalEncoding(nn.Module): + def __init__(self, len_max): + super().__init__() + self.len_max = len_max + + # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D})) + + def forward(self, bs): + if bs.init_cache: + t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[ + :, None + ] + j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[ + None, : + ] + k = j % 2 + self.pe = torch.sin( + t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k + ) + self.cache_y = bs.x.new(bs.x.size()) + + self.cache_y[:, bs.first : bs.first + bs.nb] = ( + bs.slice() + self.pe[bs.first : bs.first + bs.nb] + ) + + return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache) + + +import pscan + + +# X is /.../xTxD A is /.../xT Y_init is /.../xD + + +def pscan_dim(A, X, Y_init, dim=-2): + s = X.size() + a, T, b = s[:dim].numel(), s[dim], s[dim + 1 :].numel() + + A = A.reshape(a, T, *s[dim + 1 : -1]) + X = X.reshape(a, T, *s[dim + 1 : -1], -1) + + if Y_init is None: + Y_init = X.new_zeros(a, *s[dim + 1 : -1], X.size(-1)) + else: + Y_init = Y_init.reshape(a, *s[dim + 1 : -1], -1) + + Y = pscan.pscan(A, X, Y_init).reshape(s) + + return Y + + +def pscan_shape(A, X, Y_init): + s = X.size() + A = A.reshape(-1, s[-2]) + X = X.reshape(-1, s[-2], s[-1]) + + if Y_init is None: + Y_init = 
X.new_zeros(X.size(0), s[-1]) + else: + Y_init = Y_init.reshape(-1, s[-1]) + + Y = pscan.pscan(A, X, Y_init).reshape(s) + + return Y + + +def nsum_shape(X, Y_init): + s = X.size() + X = X.reshape(-1, s[-2], s[-1]) # ntd + + Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1]) + result = [] + + for k in range(X.size(1)): + Y = Y + X[:, k] + Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1) + result.append(Y) + + return torch.cat(result, dim=1).reshape(s) + + +############################## + + +class DumbRec(nn.Module): + def __init__( + self, + dim_in, + dim_qk, + dim_v, + nb_heads, + nb_lines, + attention_dropout=0.0, + len_max=1e5, + ): + super().__init__() + + def randw(*d): + return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) + + self.nb_lines = nb_lines + self.attention_dropout = attention_dropout + + self.k_star = randw(nb_lines, dim_qk) + + self.w_qw = randw(nb_heads, dim_qk, dim_in) + self.w_qr = randw(nb_heads, dim_qk, dim_in) + # self.w_k = randw(nb_heads, dim_qk, dim_in) + self.w_v = randw(nb_heads, dim_v, dim_in) + self.w_o = randw(dim_v * nb_heads, dim_in) + + def reset_inner_loss(self): + self.acc_attention = 0 + self.acc_nb = 0 + + def get_inner_loss(self): + warnings.warn("l2 regularization", RuntimeWarning) + return (self.acc_attention / self.acc_nb).pow(2).sum() + # return torch.tensor([0], device=self.w_qw.device) + + def forward(self, bs): + x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb + + if bs.init_cache: + self.rec_v = x_q.new_zeros( + x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1) + ) + # self.rec_k = x_q.new_zeros( + # x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1) + # ) + self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1)) + + ###################################################################### + # Prepare the keys + + k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1) + + warnings.warn("rotating key barrel", RuntimeWarning) + k_star = self.k_star[:, None, :].expand(-1, 
x_q.size(1), -1) + t_barrel = torch.arange(t0, t1, device=k_star.device) + t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0) + l_barrel = ( + torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel + ) % k_star.size(0) + k_star = k_star[l_barrel, t_barrel] + + ###################################################################### + # Compute the recurrent state + + qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw) + + v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v) + # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k) + + aw = torch.einsum( + "nhtd,ltd->nhlt", + qw, + k_star, + ) / math.sqrt(self.w_qw.size(1)) + + aw = aw.softmax(dim=2) # nhlt + + if self.train: + self.acc_attention += aw.sum(dim=(0, 1, 3)) + self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3) + + aw = F.dropout(aw, self.attention_dropout, self.training) + + A = 1 - aw.sum(dim=1) # nlt + + V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous() + # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous() + + if t0 == 0: + V0 = None + # K0 = None + else: + V0 = self.rec_v[:, :, t0 - 1] + # K0 = self.rec_k[:, :, t0 - 1] + + self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0) + # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0) + + ###################################################################### + # compute the readout + + qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr) + + ar = torch.einsum( + "nhtd,ld->nhlt", + qr, + # self.rec_k[:, :, t0:t1], + self.k_star, + ) / math.sqrt(self.w_qr.size(1)) + + ar = ar.softmax(dim=2) # nhlt + + ar = F.dropout(ar, self.attention_dropout, self.training) + + y = torch.einsum( + "nhlt,nltd->nthd", + ar, + self.rec_v[:, :, t0:t1], + ).flatten(2) + + self.cache_y[:, t0:t1] = y @ self.w_o + + return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache) + + +############################## + + +class KVRec(nn.Module): + def __init__( + self, + dim_in, + dim_qk, + dim_v, + nb_heads, + nb_lines, + 
attention_dropout=0.0, + len_max=1e5, + ): + super().__init__() + + def randw(*d): + return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) + + self.nb_lines = nb_lines + self.attention_dropout = attention_dropout + + self.k_star = randw(nb_lines, dim_qk) + + self.w_qw = randw(nb_heads, dim_qk, dim_in) + self.w_qr = randw(nb_heads, dim_qk, dim_in) + self.w_k = randw(nb_heads, dim_qk, dim_in) + self.w_v = randw(nb_heads, dim_v, dim_in) + self.w_o = randw(dim_v * nb_heads, dim_in) + + def reset_inner_loss(self): + self.acc_attention = 0 + self.acc_nb = 0 + + def get_inner_loss(self): + warnings.warn("l2 regularization", RuntimeWarning) + return (self.acc_attention / self.acc_nb).pow(2).sum() + # return torch.tensor([0], device=self.w_qw.device) + # warnings.warn("side regularization", RuntimeWarning) + # return ( + # (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum() + # ) + # return torch.tensor([0], device=self.w_qw.device) + + def forward(self, bs): + x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb + + # n,h,l,t,d = dims(5) + + if bs.init_cache: + self.rec_v = x_q.new_zeros( + x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1) + ) + self.rec_k = x_q.new_zeros( + x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1) + ) + self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1)) + + ###################################################################### + # Prepare the keys + + k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1) + + warnings.warn("rotating key barrel", RuntimeWarning) + k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1) + t_barrel = torch.arange(t0, t1, device=k_star.device) + t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0) + l_barrel = ( + torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel + ) % k_star.size(0) + k_star = k_star[l_barrel, t_barrel] + + ###################################################################### + # Compute the recurrent 
state + + qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw) + + v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v) + k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k) + + aw = torch.einsum( + "nhtd,ltd->nhlt", + qw, + k_star, + ) / math.sqrt(self.w_qw.size(1)) + + aw = aw.softmax(dim=2) # nhlt + + if self.train: + # We want all the memory lines to be used similarly + self.acc_attention += aw.sum(dim=(0, 1, 3)) # Sum accross NxHx_xT + self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3) + + aw = F.dropout(aw, self.attention_dropout, self.training) + + A = 1 - aw.sum(dim=1) # nlt + + V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous() + K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous() + + if t0 == 0: + V0 = None + K0 = None + else: + V0 = self.rec_v[:, :, t0 - 1] + K0 = self.rec_k[:, :, t0 - 1] + + self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0) + self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0) + + ###################################################################### + # compute the readout + + qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr) + + ar = torch.einsum( + "nhtd,nltd->nhlt", + qr, + self.rec_k[:, :, t0:t1], + ) / math.sqrt(self.w_qr.size(1)) + + ar = ar.softmax(dim=2) # nhlt + + ar = F.dropout(ar, self.attention_dropout, self.training) + + y = torch.einsum( + "nhlt,nltd->nthd", + ar, + self.rec_v[:, :, t0:t1], + ).flatten(2) + + self.cache_y[:, t0:t1] = y @ self.w_o + + return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache) + + +############################## + + +def moving_window(x, dim, win_dim, win_size): + size, stride = x.size(), x.stride() + size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :] + size = size[:win_dim] + (win_size,) + size[win_dim:] + stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:] + + return x.as_strided(size=size, stride=stride) + + +############################## + + +class Caterpillar(nn.Module): + def __init__( + self, + dim_in, + dim_qk, + 
dim_v, + nb_heads, + caterpillar_length, + caterpillar_height, + attention_dropout=0.0, + len_max=1e5, + ): + super().__init__() + + warnings.warn("Caterpillar", RuntimeWarning) + + def randw(*d): + return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) + + self.caterpillar_length = caterpillar_length + self.caterpillar_height = caterpillar_height + self.attention_dropout = attention_dropout + + self.w_G = randw(nb_heads, caterpillar_height, dim_in) + self.b_G = nn.Parameter( + torch.full( + (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1) + ) + ) + + self.w_K = randw(nb_heads, dim_qk, dim_in) + self.w_V = randw(nb_heads, dim_v, dim_in) + self.w_Q = randw(nb_heads, dim_qk, dim_in) + self.w_O = randw(dim_v * nb_heads, dim_in) + + self.init_K_rec = randw(caterpillar_height, caterpillar_length, dim_qk) + self.init_V_rec = randw(caterpillar_height, caterpillar_length, dim_v) + + def reset_inner_loss(self): + self.acc_attention = 0 + self.acc_nb = 0 + + def get_inner_loss(self): + # warnings.warn("l2 regularization", RuntimeWarning) + # return (self.acc_attention / self.acc_nb).pow(2).sum() + return torch.tensor([0], device=self.w_Q.device) + + def forward(self, bs): + # Dimensions to make the source a bit clearer, that's needed + + X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb + + N = bs.x.size(0) + T = bs.x.size(1) + DV = self.w_V.size(1) + DK = self.w_K.size(1) + Dout = self.w_O.size(1) + CH = self.caterpillar_height + CL = self.caterpillar_length + + assert ( + t0 >= CL and (t1 - t0) % CL == 0 + ), f"bs.first should be greater than caterpillar_length, and bs.nb should be a multiple of caterpillar_length" + + if bs.init_cache: + self.rec_V = X.new_zeros(N, CH, T, DV) + self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :] + self.rec_K = X.new_zeros(N, CH, T, DK) + self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :] + self.cache_Y = X.new_zeros(N, T, Dout) + + 
###################################################################### + # Compute the recurrent state + + G = ( + torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None] + ).sigmoid() + + V = torch.einsum("ntc,hdc->nhtd", X, self.w_V) + K = torch.einsum("ntc,hdc->nhtd", X, self.w_K) + + A = 1 - G.sum(1) + gated_V = torch.einsum("nhet,nhtd->netd", G, V) + gated_K = torch.einsum("nhet,nhtd->netd", G, K) + + init_rec_V = self.rec_V[:, :, t0 - CL : t0] + init_rec_K = self.rec_K[:, :, t0 - CL : t0] + + A = A.unflatten(2, (-1, CL)) + gated_V = gated_V.unflatten(2, (-1, CL)) + gated_K = gated_K.unflatten(2, (-1, CL)) + + next_V = pscan_dim(A, gated_V, init_rec_V, dim=2) + next_K = pscan_dim(A, gated_K, init_rec_K, dim=2) + + self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3) + self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3) + + ###################################################################### + # compute the readout + + Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q) + + uv = moving_window( + self.rec_V[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL + ) + + uk = moving_window( + self.rec_K[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL + ) + + ar = torch.einsum( + "nhtd,nftld->nhtfl", + Q, + uk, + ) / math.sqrt(DK) + + ar = ar.flatten(3).softmax(dim=3).view(ar.size()) + + ar = F.dropout(ar, self.attention_dropout, self.training) + + Y = torch.einsum( + "nhtfl,nftld->nthd", + ar, + uv, + ).flatten(2) + + self.cache_Y[:, t0:t1] = Y @ self.w_O + + return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache) + + +############################## + + +class QKVAttention(nn.Module): + def __init__( + self, + dim_in, + dim_qk, + dim_v, + nb_heads=1, + causal=False, + attention_dropout=0.0, + ): + super().__init__() + + def randw(*d): + return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) + + self.causal = causal + self.attention_dropout = attention_dropout + self.record_attention = False + + self.w_q = randw(nb_heads, dim_qk, dim_in) + 
self.w_k = randw(nb_heads, dim_qk, dim_in) + self.w_v = randw(nb_heads, dim_v, dim_in) + self.w_o = randw(dim_v * nb_heads, dim_in) + + def forward(self, bs): + x_q = bs.x + + assert ( + self.causal or bs.complete() + ), "Partial evaluation is only possible for causal models" + + if bs.init_cache: + self.cache_k = x_q.new_zeros( + x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1) + ) + self.cache_v = x_q.new_zeros( + x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1) + ) + self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1)) + + q = torch.einsum("ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_q) + + self.cache_k[:, :, bs.first : bs.first + bs.nb] = torch.einsum( + "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_k + ) + self.cache_v[:, :, bs.first : bs.first + bs.nb] = torch.einsum( + "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_v + ) + + a = torch.einsum( + "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs.first + bs.nb] + ) / math.sqrt(self.w_q.size(1)) + + if self.causal: + if bs.init_cache: + self.cache_attzero = ( + torch.arange(x_q.size(1), device=q.device)[None, None, :, None] + < torch.arange(x_q.size(1), device=q.device)[None, None, None, :] + ) + a = a.masked_fill( + self.cache_attzero[ + :, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb + ], + float("-inf"), + ) + + a = a.softmax(dim=3) + + if self.record_attention: + self.a = a + + a = F.dropout(a, self.attention_dropout, self.training) + + y = torch.einsum( + "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs.first + bs.nb] + ).flatten(2) + + self.cache_y[:, bs.first : bs.first + bs.nb] = y @ self.w_o + + return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache) + + +############################## + + +class MyGPT(nn.Module): + def __init__( + self, + vocabulary_size, + dim_model, + dim_keys, + dim_hidden, + nb_heads, + nb_blocks, + nb_lines=None, + caterpillar_height=None, + dim_rec_v=-1, + causal=False, + 
dropout=0.0, + len_max=1e5, + attention_layer="kvrec", + ): + super().__init__() + + assert attention_layer in {"mha", "dumbrec", "kvrec", "caterpillar"} + + if attention_layer == "caterpillar": + assert nb_lines % caterpillar_height == 0 + self.caterpillar_length = nb_lines // caterpillar_height + self.caterpillar_height = caterpillar_height + else: + self.caterpillar_length = -1 + self.caterpillar_height = -1 + + assert dim_model % nb_heads == 0 + + self.embedding = nn.Sequential( + CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)), + AddPositionalEncoding(len_max), + ) + + trunk_blocks = [] + + def attlayer(): + if attention_layer == "mha": + return QKVAttention( + dim_in=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + causal=causal, + attention_dropout=dropout, + ) + elif attention_layer == "dumbrec": + return DumbRec( + dim_in=dim_model, + dim_qk=dim_keys, + dim_v=dim_rec_v, + nb_heads=nb_heads, + nb_lines=nb_lines, + attention_dropout=dropout, + ) + elif attention_layer == "kvrec": + return KVRec( + dim_in=dim_model, + dim_qk=dim_keys, + dim_v=dim_rec_v, + nb_heads=nb_heads, + nb_lines=nb_lines, + attention_dropout=dropout, + ) + elif attention_layer == "caterpillar": + return Caterpillar( + dim_in=dim_model, + dim_qk=dim_keys, + dim_v=dim_rec_v, + nb_heads=nb_heads, + caterpillar_length=self.caterpillar_length, + caterpillar_height=self.caterpillar_height, + attention_dropout=dropout, + ) + else: + raise ValueError(f"Unknown attention type {attention_layer}.") + + for b in range(nb_blocks): + trunk_blocks += [ + WithResidual( + CacheWrapper(nn.LayerNorm((dim_model,))), + attlayer(), + ), + WithResidual( + CacheWrapper( + nn.LayerNorm((dim_model,)), + nn.Linear(in_features=dim_model, out_features=dim_hidden), + nn.ReLU(), + nn.Linear(in_features=dim_hidden, out_features=dim_model), + nn.Dropout(dropout), + ), + ), + ] + + self.trunk = nn.Sequential(*trunk_blocks) + + self.readout = CacheWrapper( + 
nn.Linear(in_features=dim_model, out_features=vocabulary_size) + ) + + with torch.no_grad(): + for m in self.modules(): + if isinstance(m, nn.Embedding): + m.weight.normal_(mean=0, std=2e-2) + elif isinstance(m, nn.LayerNorm): + m.bias.zero_() + m.weight.fill_(1.0) + + self.reset_inner_loss() + + def forward(self, bs): + bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache) + + # To make the code simpler in the Caterpillar layer, we pad + # here. It's unclear if/how much it hurts computationaly by + # increasing the sequence length for the other layers + + if self.caterpillar_length > 0: + original_nb = bs.nb + if bs.nb % self.caterpillar_length > 0: + bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length + + bs = BracketedSequence( + F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)), + bs.first + self.caterpillar_length, + bs.nb, + bs.init_cache, + ) + + bs = self.embedding(bs) + bs = self.trunk(bs) + bs = self.readout(bs) + + if self.caterpillar_length > 0: + bs = BracketedSequence( + F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)), + bs.first - self.caterpillar_length, + original_nb, + bs.init_cache, + ) + + return bs + + # ar_mask is a tensor with 0s and 1s, of same shape as input, with + # 1s where tokens should be generated. The others are kept + # unchanged. 
+ + def masked_inplace_autoregression( + self, + input_src, + ar_mask_src, + forbidden_tokens=None, + deterministic_synthesis=False, + ): + input = input_src.to(self.readout.f.weight.device) + ar_mask = ar_mask_src.to(self.readout.f.weight.device) + to_generate = (ar_mask.sum(0) > 0).nonzero() + if to_generate.min() > 0: + self( + BracketedSequence(input, 0, to_generate.min(), True) + ) # Needed to initialize the model's cache + for s in range(to_generate.min(), to_generate.max() + 1): + output = self(BracketedSequence(input, s, 1, s == 0)).x + logits = output[:, s] + if forbidden_tokens is not None: + logits = logits.masked_fill(forbidden_tokens, float("-inf")) + if deterministic_synthesis: + t_next = logits.argmax(1) + else: + dist = torch.distributions.categorical.Categorical(logits=logits) + t_next = dist.sample() + input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s] + + input_src.copy_(input) + + def reset_inner_loss(self): + for m in self.modules(): + if m is not self and hasattr(m, "reset_inner_loss"): + m.reset_inner_loss() + + def get_inner_loss(self): + l = torch.tensor([0.0], device=self.readout.f.weight.device) + for m in self.modules(): + if m is not self and hasattr(m, "get_inner_loss"): + l += m.get_inner_loss() + return l + + def record_attention(self, v=True): + for m in self.modules(): + if isinstance(m, QKVAttention): + m.record_attention = v + + def retrieve_attention(self): + a = [] + for m in self.modules(): + if isinstance(m, QKVAttention): + a.append(m.a) + return a + + +###################################################################### + +if __name__ == "__main__": + print("Basic check.") + + m = Caterpillar( + dim_in=4, + dim_qk=3, + dim_v=7, + nb_heads=1, + caterpillar_length=7, + caterpillar_height=3, + attention_dropout=0.0, + ) + + m.reset_inner_loss() + x = torch.randn(1, 21 + 2 * 7, 4) + y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28] + y2 = m(BracketedSequence(x, first=7, nb=21, 
init_cache=True)).x[:, 7:28] + y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21] + y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28] + print((y1 - y2).abs().max()) + print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max()) + exit(0) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + vocabulary_size = 128 + x = torch.randint(vocabulary_size, (6, 1024)) + + model = MyGPT( + vocabulary_size=vocabulary_size, + dim_model=512, + dim_keys=64, + dim_hidden=2048, + nb_heads=8, + nb_lines=128, + nb_blocks=12, + dropout=0.1, + causal=True, + ) + + x = x.to(device) + model.to(device) + + import time, sys + + # import torchvision.models as models + # from torch.profiler import profile, record_function, ProfilerActivity + + # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof: + # with record_function("model_inference"): + + model.eval() + for i in range(3): + start_time = time.perf_counter() + for k in range(10): + model(BracketedSequence(x)) + duration = time.perf_counter() - start_time + print(duration) + sys.stdout.flush() + + # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)) + # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + + # print("##############################################################") + # y2 = torch.randn_like(y1) + # for s in range(x.size(1)): + # z = model(BracketedSequence(x, s, 1)) + # y2[:, s : s + 1] = z.slice() + + # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}") + +###################################################################### diff --git a/picoclvr.py b/picoclvr.py new file mode 100755 index 0000000..0cd3062 --- /dev/null +++ b/picoclvr.py @@ -0,0 +1,370 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. 
+# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import math +import torch, torchvision +import torch.nn.functional as F + +color_name2rgb = { + "white": [255, 255, 255], + "red": [255, 0, 0], + "green": [0, 128, 0], + "blue": [0, 0, 255], + "yellow": [255, 255, 0], + "black": [0, 0, 0], + "maroon": [128, 0, 0], + "dark_red": [139, 0, 0], + "brown": [165, 42, 42], + "firebrick": [178, 34, 34], + "crimson": [220, 20, 60], + "tomato": [255, 99, 71], + "coral": [255, 127, 80], + "indian_red": [205, 92, 92], + "light_coral": [240, 128, 128], + "dark_salmon": [233, 150, 122], + "salmon": [250, 128, 114], + "light_salmon": [255, 160, 122], + "orange_red": [255, 69, 0], + "dark_orange": [255, 140, 0], + "orange": [255, 165, 0], + "gold": [255, 215, 0], + "dark_golden_rod": [184, 134, 11], + "golden_rod": [218, 165, 32], + "pale_golden_rod": [238, 232, 170], + "dark_khaki": [189, 183, 107], + "khaki": [240, 230, 140], + "olive": [128, 128, 0], + "yellow_green": [154, 205, 50], + "dark_olive_green": [85, 107, 47], + "olive_drab": [107, 142, 35], + "lawn_green": [124, 252, 0], + "chartreuse": [127, 255, 0], + "green_yellow": [173, 255, 47], + "dark_green": [0, 100, 0], + "forest_green": [34, 139, 34], + "lime": [0, 255, 0], + "lime_green": [50, 205, 50], + "light_green": [144, 238, 144], + "pale_green": [152, 251, 152], + "dark_sea_green": [143, 188, 143], + "medium_spring_green": [0, 250, 154], + "spring_green": [0, 255, 127], + "sea_green": [46, 139, 87], + "medium_aqua_marine": [102, 205, 170], + "medium_sea_green": [60, 179, 113], + "light_sea_green": [32, 178, 170], + "dark_slate_gray": [47, 79, 79], + "teal": [0, 128, 128], + "dark_cyan": [0, 139, 139], + "aqua": [0, 255, 255], + "cyan": [0, 255, 255], + "light_cyan": [224, 255, 255], + "dark_turquoise": [0, 206, 209], + "turquoise": [64, 224, 208], + "medium_turquoise": [72, 209, 204], + "pale_turquoise": [175, 238, 238], + "aqua_marine": [127, 255, 212], + "powder_blue": [176, 
224, 230], + "cadet_blue": [95, 158, 160], + "steel_blue": [70, 130, 180], + "corn_flower_blue": [100, 149, 237], + "deep_sky_blue": [0, 191, 255], + "dodger_blue": [30, 144, 255], + "light_blue": [173, 216, 230], + "sky_blue": [135, 206, 235], + "light_sky_blue": [135, 206, 250], + "midnight_blue": [25, 25, 112], + "navy": [0, 0, 128], + "dark_blue": [0, 0, 139], + "medium_blue": [0, 0, 205], + "royal_blue": [65, 105, 225], + "blue_violet": [138, 43, 226], + "indigo": [75, 0, 130], + "dark_slate_blue": [72, 61, 139], + "slate_blue": [106, 90, 205], + "medium_slate_blue": [123, 104, 238], + "medium_purple": [147, 112, 219], + "dark_magenta": [139, 0, 139], + "dark_violet": [148, 0, 211], + "dark_orchid": [153, 50, 204], + "medium_orchid": [186, 85, 211], + "purple": [128, 0, 128], + "thistle": [216, 191, 216], + "plum": [221, 160, 221], + "violet": [238, 130, 238], + "magenta": [255, 0, 255], + "orchid": [218, 112, 214], + "medium_violet_red": [199, 21, 133], + "pale_violet_red": [219, 112, 147], + "deep_pink": [255, 20, 147], + "hot_pink": [255, 105, 180], + "light_pink": [255, 182, 193], + "pink": [255, 192, 203], + "antique_white": [250, 235, 215], + "beige": [245, 245, 220], + "bisque": [255, 228, 196], + "blanched_almond": [255, 235, 205], + "wheat": [245, 222, 179], + "corn_silk": [255, 248, 220], + "lemon_chiffon": [255, 250, 205], + "light_golden_rod_yellow": [250, 250, 210], + "light_yellow": [255, 255, 224], + "saddle_brown": [139, 69, 19], + "sienna": [160, 82, 45], + "chocolate": [210, 105, 30], + "peru": [205, 133, 63], + "sandy_brown": [244, 164, 96], + "burly_wood": [222, 184, 135], + "tan": [210, 180, 140], + "rosy_brown": [188, 143, 143], + "moccasin": [255, 228, 181], + "navajo_white": [255, 222, 173], + "peach_puff": [255, 218, 185], + "misty_rose": [255, 228, 225], + "lavender_blush": [255, 240, 245], + "linen": [250, 240, 230], + "old_lace": [253, 245, 230], + "papaya_whip": [255, 239, 213], + "sea_shell": [255, 245, 238], + "mint_cream": [245, 
255, 250], + "slate_gray": [112, 128, 144], + "light_slate_gray": [119, 136, 153], + "light_steel_blue": [176, 196, 222], + "lavender": [230, 230, 250], + "floral_white": [255, 250, 240], + "alice_blue": [240, 248, 255], + "ghost_white": [248, 248, 255], + "honeydew": [240, 255, 240], + "ivory": [255, 255, 240], + "azure": [240, 255, 255], + "snow": [255, 250, 250], + "silver": [192, 192, 192], + "gainsboro": [220, 220, 220], + "white_smoke": [245, 245, 245], +} + +color_name2id = dict([(n, k) for k, n in enumerate(color_name2rgb.keys())]) +color_id2name = dict([(k, n) for k, n in enumerate(color_name2rgb.keys())]) + +###################################################################### + + +def all_properties(height, width, nb_squares, square_i, square_j, square_c): + s = [] + + for r, c_r in [(k, color_id2name[square_c[k].item()]) for k in range(nb_squares)]: + s += [f"there is {c_r}"] + + if square_i[r] >= height - height // 3: + s += [f"{c_r} bottom"] + if square_i[r] < height // 3: + s += [f"{c_r} top"] + if square_j[r] >= width - width // 3: + s += [f"{c_r} right"] + if square_j[r] < width // 3: + s += [f"{c_r} left"] + + for t, c_t in [ + (k, color_id2name[square_c[k].item()]) for k in range(nb_squares) + ]: + if square_i[r] > square_i[t]: + s += [f"{c_r} below {c_t}"] + if square_i[r] < square_i[t]: + s += [f"{c_r} above {c_t}"] + if square_j[r] > square_j[t]: + s += [f"{c_r} right of {c_t}"] + if square_j[r] < square_j[t]: + s += [f"{c_r} left of {c_t}"] + + return s + + +###################################################################### + +# Generates sequences + + +def generate( + nb, + height, + width, + max_nb_squares=5, + max_nb_properties=10, + nb_colors=5, + pruner=None, +): + assert nb_colors >= max_nb_squares and nb_colors <= len(color_name2rgb) - 1 + + descr = [] + + for n in range(nb): + # we want uniform over the combinations of 1 to max_nb_squares + # pixels of nb_colors + logits = math.log(nb_colors) * torch.arange(1, max_nb_squares + 
1).float() + dist = torch.distributions.categorical.Categorical(logits=logits) + nb_squares = dist.sample((1,)) + 1 + # nb_squares = torch.randint(max_nb_squares, (1,)) + 1 + square_position = torch.randperm(height * width)[:nb_squares] + + # color 0 is white and reserved for the background + square_c = torch.randperm(nb_colors)[:nb_squares] + 1 + square_i = square_position.div(width, rounding_mode="floor") + square_j = square_position % width + + img = torch.zeros(height * width, dtype=torch.int64) + for k in range(nb_squares): + img[square_position[k]] = square_c[k] + + # generates all the true properties + + s = all_properties(height, width, nb_squares, square_i, square_j, square_c) + + if pruner is not None: + s = list(filter(pruner, s)) + + # picks at most max_nb_properties at random + + nb_properties = torch.randint(max_nb_properties, (1,)) + 1 + s = ( + " ".join([s[k] for k in torch.randperm(len(s))[:nb_properties]]) + + " " + + " ".join([f"{color_id2name[n.item()]}" for n in img]) + ) + + descr += [s] + + return descr + + +###################################################################### + +# Extracts the image after in descr as a 1x3xHxW tensor + + +def descr2img(descr, height, width): + result = [] + + def token2color(t): + try: + return color_name2rgb[t] + except KeyError: + return [128, 128, 128] + + for d in descr: + d = d.split("")[1] + d = d.strip().split(" ")[: height * width] + d = d + [""] * (height * width - len(d)) + d = [token2color(t) for t in d] + img = torch.tensor(d).permute(1, 0).reshape(1, 3, height, width) + result.append(img) + + return torch.cat(result, 0) + + +###################################################################### + +# Returns all the properties of the image after in descr + + +def descr2properties(descr, height, width): + if type(descr) == list: + return [descr2properties(d, height, width) for d in descr] + + d = descr.split("") + img_tokens = d[-1] if len(d) > 1 else "" + img_tokens = img_tokens.strip().split(" 
")[: height * width] + if len(img_tokens) != height * width: + return [] + + seen = {} + for k, x in enumerate(img_tokens): + if x != color_id2name[0]: + if x in color_name2rgb: + if x in seen: + return [] + else: + return [] + seen[x] = (color_name2id[x], k // width, k % width) + + square_infos = tuple(zip(*seen.values())) + + if square_infos: + square_c = torch.tensor(square_infos[0]) + square_i = torch.tensor(square_infos[1]) + square_j = torch.tensor(square_infos[2]) + else: + square_c = torch.tensor([]) + square_i = torch.tensor([]) + square_j = torch.tensor([]) + + s = all_properties(height, width, len(seen), square_i, square_j, square_c) + + return s + + +###################################################################### + +# Returns a triplet composed of (1) the total number of properties +# before in descr, (2) the total number of properties the image +# after verifies, and (3) the number of properties in (1) not in +# (2) + + +def nb_properties(descr, height, width, pruner=None): + if type(descr) == list: + return [nb_properties(d, height, width, pruner) for d in descr] + + d = descr.split("", 1) + if len(d) == 0: + return 0 + d = d[0].strip().split("") + d = [x.strip() for x in d] + + all_properties = set(descr2properties(descr, height, width)) + + if pruner is None: + requested_properties = set(d) + else: + requested_properties = set(filter(pruner, d)) + + missing_properties = requested_properties - all_properties + + return (len(requested_properties), len(all_properties), len(missing_properties)) + + +###################################################################### + +if __name__ == "__main__": + for n in range(16): + descr = generate(nb=1, height=12, width=16) + + print(nb_properties(descr, height=12, width=16)) + + with open(f"picoclvr_example_{n:02d}.txt", "w") as f: + for d in descr: + f.write(f"{d}\n\n") + + img = descr2img(descr, height=12, width=16) + if img.size(0) == 1: + img = F.pad(img, (1, 1, 1, 1), value=64) + + 
torchvision.utils.save_image( + img / 255.0, + f"picoclvr_example_{n:02d}.png", + padding=1, + nrow=4, + pad_value=0.8, + ) + + import time + + start_time = time.perf_counter() + descr = generate(nb=1000, height=12, width=16) + end_time = time.perf_counter() + print(f"{len(descr) / (end_time - start_time):.02f} samples per second") + +###################################################################### diff --git a/problems.py b/problems.py new file mode 100755 index 0000000..9e368c2 --- /dev/null +++ b/problems.py @@ -0,0 +1,490 @@ +#!/usr/bin/env python + +import math + +import torch, torchvision + +from torch import nn +from torch.nn import functional as F + +###################################################################### + + +class Problem: + def generate_sequences(self, nb): + pass + + def seq2str(self, seq): + return "[NOT IMPLEMENTED]" + + def compute_nb_correct(self, input, ar_mask, result): + nb_total = ar_mask.sum().item() + nb_correct = ((result == input).long() * ar_mask).sum().item() + return nb_total, nb_correct + + +#################### + + +class ProblemDegradation(Problem): + def __init__(self, nb_state_tokens=5, nb_time_steps=12, value_max=25, hard=False): + assert value_max // nb_state_tokens >= 2 + self.nb_state_tokens = nb_state_tokens + self.nb_time_steps = nb_time_steps + self.value_max = value_max + self.hard = hard + + def generate_sequences(self, nb): + x = ( + torch.rand(nb, self.nb_state_tokens).sort(dim=-1).indices == 0 + ).long() * self.value_max + seq = [x] + + for t in range(self.nb_time_steps - 1): + v = (torch.rand(x.size()).sort(dim=-1).indices + 1) * (x >= 2).long() + u = (v.max(dim=-1, keepdim=True).values == v).long() + n = ( + (u * x) + .minimum(2 + torch.randint(self.value_max // 4 - 2, x.size())) + .sum(dim=-1, keepdim=True) + ) + m = 1 + ((n - 1) * torch.rand(n.size())).long() + x = ( + x + + m * u.roll(shifts=-1, dims=-1) + - n * u + + (n - m) * u.roll(shifts=1, dims=-1) + ) + seq.append(x) + + if self.hard: + 
seq.reverse() + + seq = torch.cat(seq, dim=1) + return seq, seq.new_full(seq.size(), 1, dtype=torch.int64) + + def compute_nb_correct(self, input, ar_mask, result): + nb_total = result.size(0) + nb_correct = 0 + e = result.new_zeros(self.nb_state_tokens) + + for seq in result: + states = list(seq.split(self.nb_state_tokens)) + if self.hard: + states.reverse() + + d = states[0] + j = d.sort(descending=True).indices[0] + e.zero_() + e[j] = self.value_max + if (d - e).abs().sum() == 0: + nb_errors = 0 + for k in range(len(states) - 1): + d = states[k + 1] - states[k] + j = d.sort(descending=False).indices[0] + if ( + d[j] == 0 + or d[j] > self.value_max // 4 + or d[(j + 1) % e.size(0)] <= 0 + or d[(j + 1) % e.size(0)] >= -d[j] + ): + nb_errors += 1 + else: + e.zero_() + e[j] = d[j] + e[(j + 1) % e.size(0)] = d[(j + 1) % e.size(0)] + e[(j - 1) % e.size(0)] = -d[(j + 1) % e.size(0)] - d[j] + if (d - e).abs().sum() > 0: + nb_errors += 1 + if nb_errors == 0: + nb_correct += 1 + + return nb_total, nb_correct + + def seq2str(self, seq): + return " | ".join( + [" ".join([f"{x:02d}" for x in s]) for s in seq.split(self.nb_state_tokens)] + ) + + +#################### + + +class ProblemMemory(Problem): + def __init__(self, len_total=32): + self.len_total = len_total + self.max_len_pattern = 5 + self.nb_noise_tokens = 10 + self.start_pattern_token = 0 + self.end_pattern_token = 1 + self.start_result_token = 2 + self.end_result_token = 3 + self.token_string = "[]<>" + "".join( + [chr(ord("a") + k) for k in range(self.nb_noise_tokens)] + ) + + def generate_sequences(self, nb): + sequences = ( + torch.randint(self.nb_noise_tokens, (nb, self.len_total)) + + self.end_result_token + + 1 + ) + len_patterns = torch.randint(self.max_len_pattern, (nb,)) + 1 + pattern_positions = torch.randint( + self.len_total - (5 + 2 * self.max_len_pattern), (nb,) + ) + k = self.len_total - (3 + self.max_len_pattern) + for i in range(nb): + l = len_patterns[i] + j = pattern_positions[i] + sequences[i, 
j] = self.start_pattern_token + sequences[i, j + l + 2] = self.end_pattern_token + sequences[i, k] = self.start_result_token + sequences[i, k + l + 2] = self.end_result_token + sequences[i, k + 1 : k + 2 + l] = sequences[i, j + 1 : j + 2 + l] + + j = torch.arange(self.len_total)[None, :] + ar_mask = (j > k).long() * (j <= k + 1 + len_patterns[:, None]).long() + + return sequences, ar_mask + + def seq2str(self, seq): + return "".join(self.token_string[x.item()] for x in seq) + + +class ProblemTwoTargets(Problem): + def __init__(self, len_total=10, len_targets=3): + assert len_targets >= 3 + assert len_total >= 3 * len_targets - 1 + self.len_total = len_total + self.len_targets = len_targets + + def generate_sequences(self, nb): + k = torch.arange(self.len_total)[None, :] + s = torch.randint(10, (nb, self.len_total)) + l = torch.rand(nb, self.len_total) + l = l * (k <= self.len_total - self.len_targets).long() + k1 = l.argmax(dim=1, keepdim=True) + m = (k != k1).long() * (k != k1 + self.len_targets - 1).long() + s = s * m + 10 * (1 - m) + l = l * ( + 1 + - (k + self.len_targets - 1 >= k1).long() + * (k < k1 + self.len_targets).long() + ) + k2 = l.argmax(dim=1, keepdim=True) + m = (k != k2).long() * (k != k2 + self.len_targets - 1).long() + s = s * m + 11 * (1 - m) + a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :]) + a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :]) + sequences = torch.cat( + ( + s, + torch.full((nb, 1), 12), + a1, + torch.full((nb, 1), 12), + a2, + torch.full((nb, 1), 12), + ), + 1, + ) + ar_mask = (sequences == 12).long() + ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) + return sequences, ar_mask + + def seq2str(self, seq): + return "".join("0123456789-+|"[x.item()] for x in seq) + + +#################### + + +class ProblemByHeart(Problem): + def __init__(self, nb_sentences=100, len_prompt=8, len_result=8): + self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + 
len_result)) + self.seq[:, len_prompt] = 10 + + def generate_sequences(self, nb): + sequences = self.seq[torch.randint(self.seq.size(0), (nb,))] + ar_mask = (sequences == 10).long() + ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) + return sequences, ar_mask + + def seq2str(self, seq): + return "".join("0123456789|"[x.item()] for x in seq) + + +#################### + + +class ProblemLearnOperator(Problem): + def __init__(self, nb_operators=100, len_source=6, len_result=9): + self.len_source = len_source + self.len_result = len_result + self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1 + self.operators = F.one_hot( + torch.rand(nb_operators, len_result, len_source).argmax(-1), + num_classes=len_source, + ) + + def generate_sequences(self, nb): + nb_operators = torch.randint(self.operators.size(0), (nb,)) + operators = self.operators[nb_operators] + nb_operators = ( + nb_operators[:, None] + // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1) + ) % 10 + marker1 = torch.full((nb, 1), 10) + source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source] + marker2 = torch.full((nb, 1), 11) + result = operators.bmm(source[:, :, None]).squeeze(-1) + sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1) + ar_mask = (sequences == 11).long() + ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) + return sequences, ar_mask + + def seq2str(self, seq): + return "".join("0123456789|>"[x.item()] for x in seq) + + +#################### + + +class ProblemGuessOperator(Problem): + def __init__(self, len_source=5, len_result=8): + self.len_source = len_source + self.len_result = len_result + + def generate_sequences(self, nb): + operators = F.one_hot( + torch.rand(nb, self.len_result, self.len_source).argmax(-1), + num_classes=self.len_source, + ) + source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source] + marker1 = torch.full((nb, 1), 10) + result1 = operators.bmm(source1[:, :, None]).squeeze(-1) + marker2 = 
torch.full((nb, 1), 11) + source2 = torch.randint(10, (nb, self.len_source)) + marker3 = torch.full((nb, 1), 12) + result2 = operators.bmm(source2[:, :, None]).squeeze(-1) + + sequences = torch.cat( + (source1, marker1, result1, marker2, source2, marker3, result2), 1 + ) + ar_mask = (sequences == 12).long() + ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) + return sequences, ar_mask + + def seq2str(self, seq): + return "".join("0123456789>|~"[x.item()] for x in seq) + + +#################### + + +class ProblemAddition(Problem): + def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False): + self.nb_digits = nb_digits + self.zero_padded = zero_padded + self.inverted_result = inverted_result + self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")]) + self.id2char = dict([(n, c) for c, n in self.char2id.items()]) + + def tensorize(self, strings): + len_max = max([len(x) for x in strings]) + return torch.cat( + [ + torch.tensor( + [ + [self.char2id[c] for c in s + "$" * (len_max - len(s))] + for s in strings + ] + ) + ], + 0, + ) + + def generate_sequences(self, nb): + sequences = [] + for k in range(nb): + a, b = torch.randint(10**self.nb_digits, (2,)) + c = a + b + a, b, c = str(a.item()), str(b.item()), str(c.item()) + if self.zero_padded: + a = "0" * (self.nb_digits - len(a)) + a + b = "0" * (self.nb_digits - len(b)) + b + c = "0" * (self.nb_digits + 1 - len(c)) + c + if self.inverted_result: + c = c[::-1] + sequences.append(f"{a}+{b}={c}$") + + sequences = self.tensorize(sequences) + ar_mask = (sequences == self.char2id["="]).long() + ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) + return sequences, ar_mask + + def seq2str(self, seq): + return "".join(self.id2char[x.item()] for x in seq) + + +#################### + + +class ProblemMixing(Problem): + def __init__( + self, height=4, width=4, nb_time_steps=9, hard=False, random_start=True + ): + self.height = height + self.width = width + self.nb_time_steps = nb_time_steps + 
self.hard = hard + self.random_start = random_start + + def start_random(self, nb): + y = torch.arange(self.height * self.width).reshape(1, -1).expand(nb, -1) + + if self.random_start: + i = ( + torch.arange(self.height) + .reshape(1, -1, 1) + .expand(nb, self.height, self.width) + ) + j = ( + torch.arange(self.width) + .reshape(1, 1, -1) + .expand(nb, self.height, self.width) + ) + + ri = torch.randint(self.height, (nb,)).reshape(nb, 1, 1) + rj = torch.randint(self.width, (nb,)).reshape(nb, 1, 1) + + m = 1 - torch.logical_or(i == ri, j == rj).long().flatten(1) + + y = y * m + self.height * self.width * (1 - m) + + y = y.reshape(nb, self.height, self.width) + + return y + + def start_error(self, x): + if self.random_start: + i = ( + torch.arange(self.height, device=x.device) + .reshape(1, -1, 1) + .expand_as(x) + ) + j = torch.arange(self.width, device=x.device).reshape(1, 1, -1).expand_as(x) + + ri = ( + (x == self.height * self.width) + .long() + .sum(dim=-1) + .argmax(-1) + .view(-1, 1, 1) + ) + rj = ( + (x == self.height * self.width) + .long() + .sum(dim=-2) + .argmax(-1) + .view(-1, 1, 1) + ) + + m = 1 - torch.logical_or(i == ri, j == rj).long().flatten(1) + else: + m = 1 + + x = x.flatten(1) + u = torch.arange(self.height * self.width, device=x.device).reshape(1, -1) + + d = (x - (m * u + (1 - m) * self.height * self.width)).abs().sum(-1) + + return d + + def moves(self, x): + y = ( + x[:, None, :, :] + .expand(-1, self.height * 2 + self.width * 2, -1, -1) + .clone() + ) + k = 0 + + for i in range(self.height): + y[:, k, i, :] = y[:, k, i, :].roll(dims=-1, shifts=-1) + k += 1 + y[:, k, i, :] = y[:, k, i, :].roll(dims=-1, shifts=1) + k += 1 + + for j in range(self.width): + y[:, k, :, j] = y[:, k, :, j].roll(dims=-1, shifts=-1) + k += 1 + y[:, k, :, j] = y[:, k, :, j].roll(dims=-1, shifts=1) + k += 1 + + return y + + def generate_sequences(self, nb): + x = self.start_random(nb) + + seq = [x.flatten(1)] + + for t in range(self.nb_time_steps - 1): + y = 
self.moves(x) + x = y[torch.arange(nb), torch.randint(y.size(1), (nb,))] + seq.append(x.flatten(1)) + + if self.hard: + seq.reverse() + + seq = torch.cat(seq, dim=1) + return seq, seq.new_full(seq.size(), 1, dtype=torch.int64) + + def compute_nb_correct(self, input, ar_mask, result): + a = [ + x.reshape(result.size(0), self.height, self.width) + for x in result.split(self.height * self.width, dim=1) + ] + if self.hard: + a.reverse() + + x = a[0] + + d = self.start_error(x) + + for t in range(self.nb_time_steps - 1): + x0, x = a[t], a[t + 1] + y = self.moves(x0) + d = d + (x[:, None] - y).abs().sum((-1, -2)).min(dim=-1).values + + nb_total, nb_correct = result.size(0), (d == 0).long().sum().item() + + return nb_total, nb_correct + + def seq2str(self, seq): + return " | ".join( + [ + " ".join( + [ + "-".join( + [ + f"{x:02d}" if x < self.height * self.width else "**" + for x in s + ] + ) + for s in r.split(self.width) + ] + ) + for r in seq.split(self.height * self.width) + ] + ) + + +#################### + +if __name__ == "__main__": + p = ProblemMixing(height=3, width=3, random_start=False) + + s, m = p.generate_sequences(10000) + for x in s[:5]: + print(p.seq2str(x)) + print(p.compute_nb_correct(None, None, s)) diff --git a/pscan.py b/pscan.py new file mode 100755 index 0000000..0ec7b13 --- /dev/null +++ b/pscan.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. 
+# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import torch + +###################################################################### + + +class PScan(torch.autograd.Function): + # Given A is NxTx1 and X is NxTxD, expands A and X in place in O(T), + # and O(log(T)) if not core-bounded, so that + # + # Y[:, 0] = Y_init + # Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t] + # + # can be computed as + # + # Y[:, t] = A[:, t] * Y_init + X[:, t] + + @staticmethod + def expand_(A, X): + if A.size(1) == 1: + return + T = 2 * (A.size(1) // 2) + Aa = A[:, :T].view(A.size(0), T // 2, 2, -1, 1) + Xa = X[:, :T].view(X.size(0), T // 2, 2, -1, X.size(-1)) + Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0])) + Aa[:, :, 1].mul_(Aa[:, :, 0]) + PScan.expand_(Aa[:, :, 1], Xa[:, :, 1]) + Xa[:, 1:, 0].add_(Aa[:, 1:, 0].mul(Xa[:, :-1, 1])) + Aa[:, 1:, 0].mul_(Aa[:, :-1, 1]) + if T < A.size(1): + X[:, -1].add_(A[:, -1].mul(X[:, -2])) + A[:, -1].mul_(A[:, -2]) + + @staticmethod + def acc_rev_(A, X): + if X.size(1) == 1: + return + T = 2 * (X.size(1) // 2) + Aa = A[:, -T:].view(A.size(0), T // 2, 2, -1, 1) + Xa = X[:, -T:].view(X.size(0), T // 2, 2, -1, X.size(-1)) + Xa[:, :, 0].add_(Aa[:, :, 1].mul(Xa[:, :, 1])) + B = Aa[:, :, 0].clone() + B[:, 1:].mul_(Aa[:, :-1, 1]) + PScan.acc_rev_(B, Xa[:, :, 0]) + Xa[:, :-1, 1].add_(Aa[:, 1:, 0].mul(Xa[:, 1:, 0])) + if T < A.size(1): + X[:, 0].add_(A[:, 1].mul(X[:, 1])) + + # A is NxT, X is NxTxD, Y_init is NxD + # + # returns Y of same shape as X, with + # + # Y[:, t] = A[:, 0] * Y_init + X[:, 0] if t == 0 + # = A[:, t] * Y[:, t-1] + X[:, t] otherwise + + @staticmethod + def forward(ctx, A, X, Y_init): + ctx.A = A.unsqueeze(-1).clone() + ctx.Y_init = Y_init[:, None].clone() + ctx.A_star = ctx.A.clone() + ctx.X_star = X.clone() + PScan.expand_(ctx.A_star, ctx.X_star) + return ctx.A_star * ctx.Y_init + ctx.X_star + + @staticmethod + def backward(ctx, grad_output): + U = grad_output * ctx.A_star + A = ctx.A.clone() + R = 
grad_output.clone() + PScan.acc_rev_(A, R) + Q = ctx.Y_init.expand_as(ctx.X_star).clone() + Q[:, 1:].mul_(ctx.A_star[:, :-1]).add_(ctx.X_star[:, :-1]) + return (Q * R).sum(-1), R, U.sum(dim=1) + + +pscan = PScan.apply + +###################################################################### + +if __name__ == "__main__": + import time, sys + + A = torch.rand(17, 12, 3) + X = torch.rand(17, 12, 3, 11) + Y_init = torch.rand(17, 3, 11) + Y = pscan(A, X, Y_init) + exit(0) + + N, T, D = 2, 1047, 3 + + A = torch.rand(N, T, dtype=torch.float64).requires_grad_() + X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_() + Y_init = torch.randn(N, D, dtype=torch.float64).requires_grad_() + + # Iterative implementation + + y = Y_init + s = 0 + + for k in range(A.size(1)): + y = A[:, k, None] * y + X[:, k] + s = s + y + + s = s.sum() + + gA_ref, gX_ref, gY_init_ref = torch.autograd.grad( + s, (A, X, Y_init), retain_graph=True + ) + + # parallel scan + + start_time = time.perf_counter() + for _ in range(1000): + Y = pscan(A, X, Y_init) + duration = time.perf_counter() - start_time + print(f"duration {duration}") + + s = Y.sum() + + gA, gX, gY_init = torch.autograd.grad(s, (A, X, Y_init), retain_graph=True) + + # print(gA) + # print(gX) + # print(gY_init) + + print((gA - gA_ref).norm()) + print((gX - gX_ref).norm()) + print((gY_init - gY_init_ref).norm()) + + Y1 = pscan(A[:, : T // 2], X[:, : T // 2], Y_init) + Y2 = pscan(A[:, T // 2 :], X[:, T // 2 :], Y1[:, -1]) + + print((Y - torch.cat([Y1, Y2], dim=1)).norm()) diff --git a/qmlp.py b/qmlp.py new file mode 100755 index 0000000..abebfc1 --- /dev/null +++ b/qmlp.py @@ -0,0 +1,378 @@ +#!/usr/bin/env python + +# @XREMOTE_HOST: elk.fleuret.org +# @XREMOTE_EXEC: python +# @XREMOTE_PRE: source ${HOME}/misc/venv/pytorch/bin/activate +# @XREMOTE_PRE: killall -u ${USER} -q -9 python || true +# @XREMOTE_PRE: ln -sf ${HOME}/data/pytorch ./data +# @XREMOTE_SEND: *.py *.sh + +# Any copyright is dedicated to the Public Domain. 
+# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import math, sys + +import torch, torchvision + +from torch import nn +from torch.nn import functional as F + +###################################################################### + +nb_quantization_levels = 101 + + +def quantize(x, xmin, xmax): + return ( + ((x - xmin) / (xmax - xmin) * nb_quantization_levels) + .long() + .clamp(min=0, max=nb_quantization_levels - 1) + ) + + +def dequantize(q, xmin, xmax): + return q / nb_quantization_levels * (xmax - xmin) + xmin + + +###################################################################### + + +def generate_sets_and_params( + batch_nb_mlps, + nb_samples, + batch_size, + nb_epochs, + device=torch.device("cpu"), + print_log=False, + save_as_examples=False, +): + data_input = torch.zeros(batch_nb_mlps, 2 * nb_samples, 2, device=device) + data_targets = torch.zeros( + batch_nb_mlps, 2 * nb_samples, dtype=torch.int64, device=device + ) + + nb_rec = 8 + nb_values = 2 # more increases the min-max gap + + rec_support = torch.empty(batch_nb_mlps, nb_rec, 4, device=device) + + while (data_targets.float().mean(-1) - 0.5).abs().max() > 0.1: + i = (data_targets.float().mean(-1) - 0.5).abs() > 0.1 + nb = i.sum() + support = torch.rand(nb, nb_rec, 2, nb_values, device=device) * 2 - 1 + support = support.sort(-1).values + support = support[:, :, :, torch.tensor([0, nb_values - 1])].view(nb, nb_rec, 4) + + x = torch.rand(nb, 2 * nb_samples, 2, device=device) * 2 - 1 + y = ( + ( + (x[:, None, :, 0] >= support[:, :, None, 0]).long() + * (x[:, None, :, 0] <= support[:, :, None, 1]).long() + * (x[:, None, :, 1] >= support[:, :, None, 2]).long() + * (x[:, None, :, 1] <= support[:, :, None, 3]).long() + ) + .max(dim=1) + .values + ) + + data_input[i], data_targets[i], rec_support[i] = x, y, support + + train_input, train_targets = ( + data_input[:, :nb_samples], + data_targets[:, :nb_samples], + ) + test_input, test_targets = data_input[:, 
nb_samples:], data_targets[:, nb_samples:] + + q_train_input = quantize(train_input, -1, 1) + train_input = dequantize(q_train_input, -1, 1) + + q_test_input = quantize(test_input, -1, 1) + test_input = dequantize(q_test_input, -1, 1) + + if save_as_examples: + a = ( + 2 + * torch.arange(nb_quantization_levels).float() + / (nb_quantization_levels - 1) + - 1 + ) + xf = torch.cat( + [ + a[:, None, None].expand( + nb_quantization_levels, nb_quantization_levels, 1 + ), + a[None, :, None].expand( + nb_quantization_levels, nb_quantization_levels, 1 + ), + ], + 2, + ) + xf = xf.reshape(1, -1, 2).expand(min(q_train_input.size(0), 10), -1, -1) + print(f"{xf.size()=} {x.size()=}") + yf = ( + ( + (xf[:, None, :, 0] >= rec_support[: xf.size(0), :, None, 0]).long() + * (xf[:, None, :, 0] <= rec_support[: xf.size(0), :, None, 1]).long() + * (xf[:, None, :, 1] >= rec_support[: xf.size(0), :, None, 2]).long() + * (xf[:, None, :, 1] <= rec_support[: xf.size(0), :, None, 3]).long() + ) + .max(dim=1) + .values + ) + + full_input, full_targets = xf, yf + + q_full_input = quantize(full_input, -1, 1) + full_input = dequantize(q_full_input, -1, 1) + + for k in range(q_full_input[:10].size(0)): + with open(f"example_full_{k:04d}.dat", "w") as f: + for u, c in zip(full_input[k], full_targets[k]): + f.write(f"{c} {u[0].item()} {u[1].item()}\n") + + for k in range(q_train_input[:10].size(0)): + with open(f"example_train_{k:04d}.dat", "w") as f: + for u, c in zip(train_input[k], train_targets[k]): + f.write(f"{c} {u[0].item()} {u[1].item()}\n") + + hidden_dim = 32 + w1 = torch.randn(batch_nb_mlps, hidden_dim, 2, device=device) / math.sqrt(2) + b1 = torch.zeros(batch_nb_mlps, hidden_dim, device=device) + w2 = torch.randn(batch_nb_mlps, 2, hidden_dim, device=device) / math.sqrt( + hidden_dim + ) + b2 = torch.zeros(batch_nb_mlps, 2, device=device) + + w1.requires_grad_() + b1.requires_grad_() + w2.requires_grad_() + b2.requires_grad_() + optimizer = torch.optim.Adam([w1, b1, w2, b2], lr=1e-2) + 
+ criterion = nn.CrossEntropyLoss() + criterion.to(device) + + for k in range(nb_epochs): + acc_train_loss = 0.0 + nb_train_errors = 0 + + for input, targets in zip( + train_input.split(batch_size, dim=1), train_targets.split(batch_size, dim=1) + ): + h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :] + h = F.relu(h) + output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :] + loss = F.cross_entropy( + output.reshape(-1, output.size(-1)), targets.reshape(-1) + ) + acc_train_loss += loss.item() * input.size(0) + + wta = output.argmax(-1) + nb_train_errors += (wta != targets).long().sum(-1) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + with torch.no_grad(): + for p in [w1, b1, w2, b2]: + m = ( + torch.rand(p.size(), device=p.device) <= k / (nb_epochs - 1) + ).long() + pq = quantize(p, -2, 2) + p[...] = (1 - m) * p + m * dequantize(pq, -2, 2) + + train_error = nb_train_errors / train_input.size(1) + acc_train_loss = acc_train_loss / train_input.size(1) + + # print(f"{k=} {acc_train_loss=} {train_error=}") + + acc_test_loss = 0 + nb_test_errors = 0 + + for input, targets in zip( + test_input.split(batch_size, dim=1), test_targets.split(batch_size, dim=1) + ): + h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :] + h = F.relu(h) + output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :] + loss = F.cross_entropy(output.reshape(-1, output.size(-1)), targets.reshape(-1)) + acc_test_loss += loss.item() * input.size(0) + + wta = output.argmax(-1) + nb_test_errors += (wta != targets).long().sum(-1) + + test_error = nb_test_errors / test_input.size(1) + q_params = torch.cat( + [quantize(p.view(batch_nb_mlps, -1), -2, 2) for p in [w1, b1, w2, b2]], dim=1 + ) + q_train_set = torch.cat([q_train_input, train_targets[:, :, None]], -1).reshape( + batch_nb_mlps, -1 + ) + q_test_set = torch.cat([q_test_input, test_targets[:, :, None]], -1).reshape( + batch_nb_mlps, -1 + ) + + return q_train_set, q_test_set, q_params, test_error + + 
+###################################################################### + + +def evaluate_q_params( + q_params, + q_set, + batch_size=25, + device=torch.device("cpu"), + nb_mlps_per_batch=1024, + save_as_examples=False, +): + errors = [] + nb_mlps = q_params.size(0) + + for n in range(0, nb_mlps, nb_mlps_per_batch): + batch_nb_mlps = min(nb_mlps_per_batch, nb_mlps - n) + batch_q_params = q_params[n : n + batch_nb_mlps] + batch_q_set = q_set[n : n + batch_nb_mlps] + hidden_dim = 32 + w1 = torch.empty(batch_nb_mlps, hidden_dim, 2, device=device) + b1 = torch.empty(batch_nb_mlps, hidden_dim, device=device) + w2 = torch.empty(batch_nb_mlps, 2, hidden_dim, device=device) + b2 = torch.empty(batch_nb_mlps, 2, device=device) + + with torch.no_grad(): + k = 0 + for p in [w1, b1, w2, b2]: + print(f"{p.size()=}") + x = dequantize( + batch_q_params[:, k : k + p.numel() // batch_nb_mlps], -2, 2 + ).view(p.size()) + p.copy_(x) + k += p.numel() // batch_nb_mlps + + batch_q_set = batch_q_set.view(batch_nb_mlps, -1, 3) + data_input = dequantize(batch_q_set[:, :, :2], -1, 1).to(device) + data_targets = batch_q_set[:, :, 2].to(device) + + print(f"{data_input.size()=} {data_targets.size()=}") + + criterion = nn.CrossEntropyLoss() + criterion.to(device) + + acc_loss = 0.0 + nb_errors = 0 + + for input, targets in zip( + data_input.split(batch_size, dim=1), data_targets.split(batch_size, dim=1) + ): + h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :] + h = F.relu(h) + output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :] + loss = F.cross_entropy( + output.reshape(-1, output.size(-1)), targets.reshape(-1) + ) + acc_loss += loss.item() * input.size(0) + wta = output.argmax(-1) + nb_errors += (wta != targets).long().sum(-1) + + errors.append(nb_errors / data_input.size(1)) + acc_loss = acc_loss / data_input.size(1) + + return torch.cat(errors) + + +###################################################################### + + +def generate_sequence_and_test_set( + nb_mlps, + 
nb_samples, + batch_size, + nb_epochs, + device, + nb_mlps_per_batch=1024, +): + seqs, q_test_sets, test_errors = [], [], [] + + for n in range(0, nb_mlps, nb_mlps_per_batch): + q_train_set, q_test_set, q_params, test_error = generate_sets_and_params( + batch_nb_mlps=min(nb_mlps_per_batch, nb_mlps - n), + nb_samples=nb_samples, + batch_size=batch_size, + nb_epochs=nb_epochs, + device=device, + ) + + seqs.append( + torch.cat( + [ + q_train_set, + q_train_set.new_full( + ( + q_train_set.size(0), + 1, + ), + nb_quantization_levels, + ), + q_params, + ], + dim=-1, + ) + ) + + q_test_sets.append(q_test_set) + test_errors.append(test_error) + + seq = torch.cat(seqs) + q_test_set = torch.cat(q_test_sets) + test_error = torch.cat(test_errors) + + return seq, q_test_set, test_error + + +###################################################################### + +if __name__ == "__main__": + import time + + batch_nb_mlps, nb_samples = 128, 250 + + generate_sets_and_params( + batch_nb_mlps=10, + nb_samples=nb_samples, + batch_size=25, + nb_epochs=100, + device=torch.device("cpu"), + print_log=False, + save_as_examples=True, + ) + + exit(0) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + start_time = time.perf_counter() + + data = [] + + seq, q_test_set, test_error = generate_sequence_and_test_set( + nb_mlps=batch_nb_mlps, + nb_samples=nb_samples, + device=device, + batch_size=25, + nb_epochs=250, + nb_mlps_per_batch=17, + ) + + end_time = time.perf_counter() + print(f"{seq.size(0) / (end_time - start_time):.02f} samples per second") + + q_train_set = seq[:, : nb_samples * 3] + q_params = seq[:, nb_samples * 3 + 1 :] + print(f"SANITY #2 {q_train_set.size()=} {q_params.size()=} {seq.size()=}") + error_train = evaluate_q_params(q_params, q_train_set, nb_mlps_per_batch=17) + print(f"train {error_train*100}%") + error_test = evaluate_q_params(q_params, q_test_set, nb_mlps_per_batch=17) + print(f"test {error_test*100}%") diff --git a/rpl.py b/rpl.py new 
file mode 100755 index 0000000..b848afa --- /dev/null +++ b/rpl.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import math + +import torch, torchvision + +from torch import nn +from torch.nn import functional as F + +###################################################################### + + +def rpl_exec(program, stack): + stack = stack.copy() + for op in program: + if op == "add": + if len(stack) > 1: + a, b = stack.pop(), stack.pop() + stack.append(a + b) + elif op == "min": + if len(stack) > 1: + a, b = stack.pop(), stack.pop() + stack.append(min(a, b)) + elif op == "max": + if len(stack) > 1: + a, b = stack.pop(), stack.pop() + stack.append(max(a, b)) + elif op == "swp": + if len(stack) > 1: + a, b = stack.pop(), stack.pop() + stack.append(a) + stack.append(b) + elif op == "rep": + if len(stack) > 1: + a, b = stack.pop(), stack.pop() + stack += [b] * a + elif op == "dup": + if len(stack) > 0: + a = stack.pop() + stack.append(a) + stack.append(a) + elif op == "del": + if len(stack) > 0: + a = stack.pop() + else: + raise ValueError(f"Unknown instruction {op}") + + return stack + + +rpl_ops = ["add", "min", "max", "swp", "rep", "dup", "del"] + +###################################################################### + + +def generate( + nb_starting_values=3, nb_result_values_max=None, max_input=9, prog_len=6, nb_runs=5 +): + prog_len = (1 + torch.randint(2 * prog_len, (1,))).clamp(max=prog_len).item() + + while True: + no_empty_stack = True + prog = [rpl_ops[k] for k in torch.randint(len(rpl_ops), (prog_len,))] + + result = [] + for _ in range(nb_runs): + stack = [ + x.item() for x in torch.randint(max_input + 1, (nb_starting_values,)) + ] + result_stack = rpl_exec(prog, stack) + if len(result_stack) == 0: + no_empty_stack = False + result = result + [""] + stack + [""] + result_stack + + result = result + [""] + prog + result 
= result + [""] + + if no_empty_stack and ( + nb_result_values_max is None or len(result_stack) <= nb_result_values_max + ): + break + + return result + + +def next_marker(seq, tokens, start=0): + pos = None + for t in tokens: + try: + i = seq.index(t, start) + if pos is None or i < pos: + pos = i + except ValueError: + pass + return pos + + +def decompose(seq): + io = [] + k = 0 + while seq[k] == "": + o = next_marker(seq, [""], start=k + 1) + if o is None: + raise ValueError("Missing output markers (should be correct in the prompt)") + e = next_marker(seq, ["", ""], start=o) + if e is None: + raise ValueError( + "Missing input/output markers (should be correct in the prompt)" + ) + try: + io.append( + ([int(x) for x in seq[k + 1 : o]], [int(x) for x in seq[o + 1 : e]]) + ) + except ValueError: + raise ValueError( + "Invalid input/output value (should be correct in the prompt)" + ) + + k = e + + if seq[k] == "": + e = next_marker(seq, [""], start=k) + if e is None: + prog = [] + else: + prog = seq[k + 1 : e] + else: + raise ValueError("Missing (it should be in the prompt)") + + return prog, io + + +def stack_distance(target_stack, result_stack): + return abs(len(result_stack) - len(target_stack)) + sum( + [0 if x == y else 1 for x, y in zip(result_stack, target_stack)] + ) + + +def compute_nb_errors(seq): + prog, io = decompose(seq) + + nb_total, nb_errors = 0, 0 + + stacks = [] + + if len(set(prog) - set(rpl_ops)) > 0: + # Program is not valid, we count 100% error + for start_stack, target_stack in io: + stacks.append((start_stack, target_stack, ["N/A"], False)) + nb_total += len(target_stack) + nb_errors += len(target_stack) + + else: + # Program is valid + for start_stack, target_stack in io: + result_stack = rpl_exec(prog, start_stack) + nb_total += len(target_stack) + e = stack_distance(target_stack, result_stack) + nb_errors += e + stacks.append((start_stack, target_stack, result_stack, e == 0)) + + return nb_total, nb_errors, prog, stacks + + 
+###################################################################### + +if __name__ == "__main__": + seq = generate() + print(seq) + seq[3] = 7 + print(seq) + print(compute_nb_errors(seq)) diff --git a/snake.py b/snake.py new file mode 100755 index 0000000..8a16f9f --- /dev/null +++ b/snake.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import torch, torchvision +import torch.nn.functional as F + + +def generate_sequences( + nb, height, width, nb_colors, length, prompt_length, device=torch.device("cpu") +): + worlds = torch.randint(nb_colors, (nb, height, width), device=device) + world_prior_visits = torch.zeros(nb, height, width, device=device) + + # nb x 2 + snake_position = torch.cat( + ( + torch.randint(height, (nb, 1), device=device), + torch.randint(width, (nb, 1), device=device), + ), + 1, + ) + snake_direction = torch.randint(4, (nb,), device=device) + sequences = torch.empty(nb, 2 * length, device=device, dtype=torch.int64) + sequences_prior_visits = torch.zeros( + nb, 2 * length, device=device, dtype=torch.int64 + ) + i = torch.arange(nb, device=device) # [:,None] + + for l in range(length): + # nb x 3 + snake_next_direction = torch.cat( + ( + (snake_direction[:, None] - 1) % 4, + snake_direction[:, None], + (snake_direction[:, None] + 1) % 4, + ), + 1, + ) + + # nb x 3 + vh = (snake_next_direction + 1) % 2 * (snake_next_direction - 1) + vw = snake_next_direction % 2 * (snake_next_direction - 2) + + # nb x 3 x 2 + snake_next_speed = torch.cat((vh[:, :, None], vw[:, :, None]), 2) + snake_next_position = snake_position[:, None, :] + snake_next_speed + + # nb x 3 + val = torch.logical_and( + torch.logical_and( + snake_next_position[:, :, 0] >= 0, snake_next_position[:, :, 0] < height + ), + torch.logical_and( + snake_next_position[:, :, 1] >= 0, snake_next_position[:, :, 1] < width + ), + ).float() + val = ( + # The 
multiplicative factors bias toward moving forward + torch.rand_like(val) + * val + * torch.tensor([[1.0, 2.0, 1.0]], device=device) + ) + + # nb + j = val.argmax(1) + snake_direction = snake_next_direction[i, j] + + sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4 + sequences_prior_visits[:, 2 * l] = world_prior_visits[ + i, snake_position[:, 0], snake_position[:, 1] + ] + if l < prompt_length: + world_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1 + sequences[:, 2 * l + 1] = snake_direction + + # nb x 2 + snake_position = snake_next_position[i, j] + + return sequences, sequences_prior_visits, worlds, world_prior_visits + + +# generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20) +# exit(0) + + +def solver(input, ar_mask): + for n in range(input.size(0)): + i, j, memory = 0, 0, {} + # print(input[n]) + # print(ar_mask[n]) + for l in range(input.size(1) // 2): + if ar_mask[n, 2 * l] == 1: + if memory.get((i, j)) is None: + input[n, 2 * l] = -1 + else: + input[n, 2 * l] = memory[(i, j)] + else: + # print(f'@3 {memory=}') + if memory.get((i, j)) is None: + memory[(i, j)] = input[n, 2 * l] + else: + assert memory[(i, j)] == input[n, 2 * l], f"n={n} l={l}" + # print(f'@1 {i=} {j=}') + d = input[n, 2 * l + 1].item() + i += (d + 1) % 2 * (d - 1) + j += d % 2 * (d - 2) + # print(f'@2 {i=} {j=}') + + +def seq2str(seq): + return "".join(["NESW123456789"[i] for i in seq]) + + +###################################################################### + +if __name__ == "__main__": + train_input, train_prior_visits, _, _ = generate_sequences( + nb=20, + height=9, + width=12, + nb_colors=5, + length=50, + prompt_length=100, + ) + + print([seq2str(s) for s in train_input]) + +###################################################################### diff --git a/stack.py b/stack.py new file mode 100755 index 0000000..543f04e --- /dev/null +++ b/stack.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python + +# Any copyright is 
dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import torch, torchvision + +###################################################################### + +# CODE_OP=[0 for push, 1 for pop] + 2 * n_stack +# CODE_VAL=val + 2 * nb_stacks + + +def generate_sequences( + nb, nb_steps, nb_stacks, nb_digits, values=None, device=torch.device("cpu") +): + stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64) + stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64) + k = torch.arange(nb) + result = torch.empty(nb, (1 + nb_digits) * nb_steps, dtype=torch.int64) + recorded_stack_counts = torch.zeros( + nb, (1 + nb_digits) * nb_steps, dtype=torch.int64 + ) + + for t in range(nb_steps): + op = torch.randint(2, (nb,)) + st = torch.randint(nb_stacks, (nb,)) + op = op * (stack_counts[k, st] > 0) + if values is None: + val_push = torch.randint(10**nb_digits, (nb,)) + else: + val_push = values[torch.randint(values.size(0), (nb,))] + val_pop = stack[ + k, + st, + (stack_counts[k, st] - 1).clamp(min=0), + ] + stack[k, st, stack_counts[k, st]] = val_push + recorded_stack_counts[:, (1 + nb_digits) * t] = stack_counts[k, st] + stack_counts[k[op == 0], st[op == 0]] += 1 + stack_counts[k[op == 1], st[op == 1]] -= 1 + result[:, (1 + nb_digits) * t] = st * 2 + op + for d in range(nb_digits): + result[:, (1 + nb_digits) * t + 1 + d] = ( + (op * val_pop + (1 - op) * val_push) // (10**d) + ) % 10 + 2 * nb_stacks + + return result.to(device), recorded_stack_counts.to(device) + + +def remove_popped_values(seq, nb_stacks, nb_digits): + m = torch.logical_and(seq % 2 == 1, seq < 2 * nb_stacks).long() + for d in range(nb_digits): + k = d + 1 + seq[:, k:] = -m[:, :-k] + (1 - m[:, :-k]) * seq[:, k:] + + +def seq_to_str(seq, nb_stacks, nb_digits, recorded_stack_counts=None): + assert seq.size(0) % (1 + nb_digits) == 0 + s = "" + for t in range(seq.size(0) // (1 + nb_digits)): + n_op = seq[(1 + nb_digits) * t] + if t > 0: 
+ s += " " + if recorded_stack_counts is not None: + s += f"[{recorded_stack_counts[(1 + nb_digits)*t]}] " + s += f"POP" if n_op % 2 == 1 else f"PSH" + if nb_stacks > 1: + s += f"_{n_op//2}" + for d in range(nb_digits): + if seq[(1 + nb_digits) * t + 1 + d] == -1: + s += " ?" + else: + s += f" {seq[(1 + nb_digits) * t + 1 + d] - 2 * nb_stacks:1d}" + return s + + +###################################################################### + +if __name__ == "__main__": + nb, nb_steps, nb_stacks, nb_digits = 150000, 20, 2, 1 + seq, recorded_stack_counts = generate_sequences( + nb=nb, + nb_steps=nb_steps, + nb_stacks=nb_stacks, + nb_digits=nb_digits, + ) + + for n in range(min(10, seq.size(0))): + print( + seq_to_str( + seq[n], + nb_stacks=nb_stacks, + nb_digits=nb_digits, + recorded_stack_counts=recorded_stack_counts[n], + ) + ) + # print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits)) + + print("-- PREPARED FOR TEST -----------------") + + remove_popped_values(seq, nb_stacks, nb_digits) + + for n in range(min(10, seq.size(0))): + print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits)) diff --git a/tasks.py b/tasks.py new file mode 100755 index 0000000..58638ed --- /dev/null +++ b/tasks.py @@ -0,0 +1,1663 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. 
+# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import math, os, tqdm + +import torch, torchvision + +from torch import nn +from torch.nn import functional as F + +from mygpt import BracketedSequence + +# from graph import save_attention_image +save_attention_image = None + +###################################################################### + + +def masked_inplace_autoregression( + model, + batch_size, + input, + ar_mask, + deterministic_synthesis, + forbidden_tokens=None, + progress_bar_desc="autoregression", + device=torch.device("cpu"), +): + assert input.size() == ar_mask.size() + + batches = zip(input.split(batch_size), ar_mask.split(batch_size)) + + if progress_bar_desc is not None: + batches = tqdm.tqdm( + batches, + dynamic_ncols=True, + desc=progress_bar_desc, + total=(input.size(0) + batch_size - 1) // batch_size, + ) + + with torch.autograd.no_grad(): + t = model.training + model.eval() + + for input, ar_mask in batches: + model.masked_inplace_autoregression( + input, ar_mask, forbidden_tokens, deterministic_synthesis + ) + + model.train(t) + + +###################################################################### + + +class Task: + def batches(self, split="train"): + pass + + def vocabulary_size(self): + pass + + def produce_results( + self, n_epoch, model, result_dir, logger, deterministic_synthesis + ): + pass + + +#################### + +import problems + + +class SandBox(Task): + def __init__( + self, + problem, + nb_train_samples, + nb_test_samples, + batch_size, + logger=None, + device=torch.device("cpu"), + max_nb_codes=1024, + ): + super().__init__() + + self.batch_size = batch_size + self.device = device + self.problem = problem + + self.train_input, self.train_ar_mask = self.problem.generate_sequences( + nb_train_samples + ) + self.test_input, self.test_ar_mask = self.problem.generate_sequences( + nb_test_samples + ) + + self.train_input, self.train_ar_mask = self.train_input.to( + device + ), 
self.train_ar_mask.to(device) + self.test_input, self.test_ar_mask = self.test_input.to( + device + ), self.test_ar_mask.to(device) + + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + + # A bit of paranoia never hurts + assert self.nb_codes <= max_nb_codes + assert self.train_input.min() >= 0 + assert self.test_input.min() >= 0 + assert tuple(x.item() for x in self.train_ar_mask.unique()) in { + (0,), + (1,), + (0, 1), + } + assert tuple(x.item() for x in self.test_ar_mask.unique()) in { + (0,), + (1,), + (0, 1), + } + + if logger is not None: + for s, a in zip(self.train_input[:100], self.train_ar_mask[:100]): + logger(f"train_sequences {self.problem.seq2str(s)}") + a = "".join(["01"[x.item()] for x in a]) + logger(f" {a}") + + def batches(self, split="train", nb_to_use=-1, desc=None): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + if nb_to_use > 0: + input = input[:nb_to_use] + if desc is None: + desc = f"epoch-{split}" + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=desc + ): + yield batch + + def vocabulary_size(self): + return self.nb_codes + + def produce_results( + self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000 + ): + def compute_accuracy(input, ar_mask, logger=None): + input, ar_mask = input[:nmax], ar_mask[:nmax] + result = input.clone() * (1 - ar_mask) + + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + progress_bar_desc=None, + device=self.device, + ) + + log_ground_truth = ar_mask.min() == 0 + + if logger is not None: + for sp, st in zip(result[:10], input[:10]): + logger( + f"test_sequences {n_epoch} prediction {self.problem.seq2str(sp)}" + ) + if log_ground_truth: + logger( + f" {n_epoch} ground truth {self.problem.seq2str(st)}" + ) + + nb_total, nb_correct = self.problem.compute_nb_correct( + input, ar_mask, result + ) + + # nb_total = 
ar_mask.sum().item() + # nb_correct = ((result == input).long() * ar_mask).sum().item() + + return nb_total, nb_correct + + train_nb_total, train_nb_correct = compute_accuracy( + self.train_input, self.train_ar_mask + ) + + logger( + f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%" + ) + + test_nb_total, test_nb_correct = compute_accuracy( + self.test_input, self.test_ar_mask, logger + ) + + logger( + f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" + ) + + logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}") + + if save_attention_image is not None: + for k in range(10): + ns = torch.randint(self.test_input.size(0), (1,)).item() + input = self.test_input[ns : ns + 1].clone() + + with torch.autograd.no_grad(): + t = model.training + model.eval() + # model.record_attention(True) + model(BracketedSequence(input)) + model.train(t) + # ram = model.retrieve_attention() + # model.record_attention(False) + + # tokens_output = [c for c in self.problem.seq2str(input[0])] + # tokens_input = ["n/a"] + tokens_output[:-1] + # for n_head in range(ram[0].size(1)): + # filename = os.path.join( + # result_dir, f"sandbox_attention_{k}_h{n_head}.pdf" + # ) + # attention_matrices = [m[0, n_head] for m in ram] + # save_attention_image( + # filename, + # tokens_input, + # tokens_output, + # attention_matrices, + # k_top=10, + ##min_total_attention=0.9, + # token_gap=12, + # layer_gap=50, + # ) + # logger(f"wrote {filename}") + + +###################################################################### + +import picoclvr + + +class PicoCLVR(Task): + # Make a tensor from a list of strings + def tensorize(self, descr): + token_descr = [s.strip().split(" ") for s in descr] + l = max([len(s) for s in token_descr]) + token_descr = [s + [""] * (l - len(s)) for s in token_descr] + id_descr = 
[[self.token2id[u] for u in s] for s in token_descr] + return torch.tensor(id_descr, device=self.device) + + # Make a list of strings from a tensor + def detensorize(self, x): + return [" ".join([self.id2token[t.item()] for t in r]) for r in x] + + # trim all the tensors in the tuple z to remove as much token from + # left and right in the first tensor. If z is a tuple, all its + # elements are trimed according to the triming for the first + def trim(self, z, token=""): + n = self.token2id[token] + if type(z) == tuple: + x = z[0] + i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0) + a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min() + return tuple([t[:, a:b] for t in z]) + else: + i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0) + a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min() + return z[:, a:b] + + ###################### + + def __init__( + self, + nb_train_samples, + nb_test_samples, + batch_size, + height, + width, + nb_colors=5, + logger=None, + device=torch.device("cpu"), + pruner_train=None, + pruner_eval=None, + ): + super().__init__() + + def generate_descr(nb, cache_suffix, pruner): + return picoclvr.generate( + nb, + height=self.height, + width=self.width, + nb_colors=nb_colors, + pruner=pruner, + ) + + self.height = height + self.width = width + self.batch_size = batch_size + self.device = device + self.pruner_train = pruner_train + self.pruner_eval = pruner_eval + + if logger is not None: + logger( + f"generating {nb_train_samples+nb_test_samples} samples (can take some time)" + ) + + self.train_descr = generate_descr( + nb_train_samples, "train", pruner=self.pruner_train + ) + self.test_descr = generate_descr(nb_test_samples, "test", pruner=None) + + # Build the tokenizer + tokens = {"", ""} + for d in [self.train_descr, self.test_descr]: + for s in d: + for t in s.strip().split(" "): + tokens.add(t) + # make this set a sorted list to get the same tensors given + # the same 
descr + tokens = list(tokens) + tokens.sort() + self.token2id = dict([(t, n) for n, t in enumerate(tokens)]) + self.id2token = dict([(n, t) for n, t in enumerate(tokens)]) + self.t_img, self.t_nul = self.token2id[""], self.token2id[""] + + # Tokenize the train and test sets + self.train_input = self.tensorize(self.train_descr) + self.test_input = self.tensorize(self.test_descr) + + def batches(self, split="train"): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}" + ): + yield self.trim(batch) + + def vocabulary_size(self): + return len(self.token2id) + + def compute_missing_properties( + self, n_epoch, model, logger, deterministic_synthesis, pruner=None + ): + acc_nb_requested_properties = [] + acc_nb_missing_properties = [] + acc_nb_results = 0 + + for input in tqdm.tqdm( + self.test_input.split(self.batch_size), + dynamic_ncols=True, + desc=f"test-properties", + ): + result = input.clone() + ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1) + result = (1 - ar_mask) * result + ar_mask * self.t_nul + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + progress_bar_desc=None, + device=self.device, + ) + + result_descr = self.detensorize(result) + np = picoclvr.nb_properties( + result_descr, + height=self.height, + width=self.width, + pruner=pruner, + ) + nb_requested_properties, _, nb_missing_properties = zip(*np) + acc_nb_requested_properties += nb_requested_properties + acc_nb_missing_properties += nb_missing_properties + acc_nb_results += len(result_descr) + + nb_requested_properties = sum(acc_nb_requested_properties) + nb_missing_properties = sum(acc_nb_missing_properties) + + prefix = "" if pruner is None else "pruned_" + logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}") + logger( + f"property_{prefix}nb {n_epoch} requested 
{sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}" + ) + logger( + f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%" + ) + + logger( + f"main_test_accuracy {n_epoch} {1-nb_missing_properties/nb_requested_properties}" + ) + + ###################################################################### + + def produce_results( + self, n_epoch, model, result_dir, logger, deterministic_synthesis + ): + self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis) + + if self.pruner_eval is not None: + self.compute_missing_properties(n_epoch, model, self.pruner_eval) + + nb_tokens_to_generate = self.height * self.width + 3 + result_descr = [] + nb_per_primer = 8 + primer = [] + + for primer_descr in [ + "red above green green top blue right of red", + "there is red there is yellow there is blue", + "red below yellow yellow below green green below blue red right yellow left green right blue left", + "green bottom yellow bottom green left of blue yellow right of blue blue top", + ]: + primer += [primer_descr + " "] * nb_per_primer + + result = self.tensorize(primer) + fill = result.new_full( + result.size()[:-1] + (self.height * self.width + 1,), self.t_nul + ) + result = torch.cat((result, fill), 1) + ar_mask = (result == self.t_nul).long() + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + result_descr = self.detensorize(result) + + np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width) + + acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np) + acc_nb_results = len(result_descr) + + nb_requested_properties = sum(acc_nb_requested_properties) + nb_missing_properties = sum(acc_nb_missing_properties) + + prefix = "demo_" + logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}") + logger( + f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} 
missing {sum(acc_nb_missing_properties)}" + ) + logger( + f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%" + ) + + img = picoclvr.descr2img(result_descr, height=self.height, width=self.width) + + if img.dim() == 5: + if img.size(1) == 1: + img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64) + else: + img = torch.cat( + [ + torchvision.utils.make_grid(x, padding=1, pad_value=64)[None] + for x in img + ], + 0, + ) + + image_name = os.path.join(result_dir, f"picoclvr_result_{n_epoch:04d}.png") + torchvision.utils.save_image( + img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0 + ) + logger(f"wrote {image_name}") + + +###################################################################### + + +class MNIST(Task): + def __init__( + self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu") + ): + super().__init__() + + self.nb_train_samples = (nb_train_samples,) + self.nb_test_samples = (nb_test_samples,) + self.batch_size = batch_size + self.device = device + data_set = torchvision.datasets.MNIST(root="./data", train=True, download=True) + self.train_input = data_set.data[:nb_train_samples].view(-1, 28 * 28).long() + data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True) + self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long() + + def batches(self, split="train", nb_to_use=-1, desc=None): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + if nb_to_use > 0: + input = input[:nb_to_use] + if desc is None: + desc = f"epoch-{split}" + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=desc + ): + yield batch + + def vocabulary_size(self): + return 256 + + def produce_results( + self, n_epoch, model, result_dir, logger, deterministic_synthesis + ): + results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64) + ar_mask = torch.full_like(results, 1) 
+ masked_inplace_autoregression( + model, + self.batch_size, + results, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png") + torchvision.utils.save_image( + 1 - results.reshape(-1, 1, 28, 28) / 255.0, + image_name, + nrow=16, + pad_value=0.8, + ) + logger(f"wrote {image_name}") + + +###################################################################### + +import maze + + +class Maze(Task): + def map2seq(self, *m): + return torch.cat([x.flatten(1) for x in m], 1) + + def seq2map(self, s): + s = s.reshape(s.size(0), -1, self.height, self.width) + return (s[:, k] for k in range(s.size(1))) + + def __init__( + self, + nb_train_samples, + nb_test_samples, + batch_size, + height, + width, + nb_walls, + device=torch.device("cpu"), + ): + super().__init__() + + self.batch_size = batch_size + self.height = height + self.width = width + self.device = device + + train_mazes, train_paths, _ = maze.create_maze_data( + nb_train_samples, + height=height, + width=width, + nb_walls=nb_walls, + progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"), + ) + self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device)) + + test_mazes, test_paths, _ = maze.create_maze_data( + nb_test_samples, + height=height, + width=width, + nb_walls=nb_walls, + progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"), + ) + self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device)) + + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + + def batches(self, split="train", nb_to_use=-1, desc=None): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + if nb_to_use > 0: + input = input[:nb_to_use] + if desc is None: + desc = f"epoch-{split}" + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=desc + ): + yield batch + + def 
vocabulary_size(self): + return self.nb_codes + + def compute_error( + self, model, split="train", nb_to_use=-1, deterministic_synthesis=False + ): + nb_total, nb_correct = 0, 0 + count = torch.zeros( + self.width * self.height, + self.width * self.height, + device=self.device, + dtype=torch.int64, + ) + + for input in self.batches(split, nb_to_use): + result = input.clone() + ar_mask = result.new_zeros(result.size()) + ar_mask[:, self.height * self.width :] = 1 + result *= 1 - ar_mask + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + progress_bar_desc=None, + device=self.device, + ) + mazes, paths = self.seq2map(result) + path_correctness = maze.path_correctness(mazes, paths) + nb_correct += path_correctness.long().sum() + nb_total += mazes.size(0) + + optimal_path_lengths = ( + (input[:, self.height * self.width :] == maze.v_path).long().sum(1) + ) + predicted_path_lengths = ( + (result[:, self.height * self.width :] == maze.v_path).long().sum(1) + ) + optimal_path_lengths = optimal_path_lengths[path_correctness] + predicted_path_lengths = predicted_path_lengths[path_correctness] + count[optimal_path_lengths, predicted_path_lengths] += 1 + + if count.max() == 0: + count = None + else: + count = count[ + : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1 + ] + + return nb_total, nb_correct, count + + def produce_results( + self, n_epoch, model, result_dir, logger, deterministic_synthesis + ): + train_nb_total, train_nb_correct, count = self.compute_error( + model, + "train", + nb_to_use=1000, + deterministic_synthesis=deterministic_synthesis, + ) + logger( + f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%" + ) + + test_nb_total, test_nb_correct, count = self.compute_error( + model, + "test", + nb_to_use=1000, + deterministic_synthesis=deterministic_synthesis, + ) + logger( + f"accuracy_test 
{n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" + ) + + logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}") + + if count is not None: + proportion_optimal = count.diagonal().sum().float() / count.sum() + logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%") + with open( + os.path.join(result_dir, f"maze_result_{n_epoch:04d}.txt"), "w" + ) as f: + for i in range(count.size(0)): + for j in range(count.size(1)): + eol = " " if j < count.size(1) - 1 else "\n" + f.write(f"{count[i,j]}{eol}") + + input = self.test_input[:48] + result = input.clone() + ar_mask = result.new_zeros(result.size()) + ar_mask[:, self.height * self.width :] = 1 + result *= 1 - ar_mask + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + + mazes, paths = self.seq2map(input) + _, predicted_paths = self.seq2map(result) + + filename = os.path.join(result_dir, f"maze_result_{n_epoch:04d}.png") + maze.save_image( + filename, + mazes=mazes, + target_paths=paths, + predicted_paths=predicted_paths, + path_correct=maze.path_correctness(mazes, predicted_paths), + path_optimal=maze.path_optimality(paths, predicted_paths), + ) + logger(f"wrote {filename}") + + +###################################################################### + + +import snake + + +class Snake(Task): + def __init__( + self, + nb_train_samples, + nb_test_samples, + batch_size, + height, + width, + nb_colors, + length, + prompt_length, + device=torch.device("cpu"), + ): + super().__init__() + + self.batch_size = batch_size + self.height = height + self.width = width + self.device = device + self.prompt_length = prompt_length + + self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences( + nb_train_samples, + height, + width, + nb_colors, + length, + prompt_length, + self.device, + ) + self.test_input, self.test_prior_visits, _, _ = 
snake.generate_sequences( + nb_test_samples, + height, + width, + nb_colors, + length, + prompt_length, + self.device, + ) + + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + + def batches(self, split="train", nb_to_use=-1, desc=None): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + if nb_to_use > 0: + input = input[:nb_to_use] + if desc is None: + desc = f"epoch-{split}" + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=desc + ): + yield batch + + def vocabulary_size(self): + return self.nb_codes + + def produce_results( + self, n_epoch, model, result_dir, logger, deterministic_synthesis + ): + def compute_nb_correct(input, prior_visits): + result = input.clone() + i = torch.arange(result.size(1), device=result.device)[None, :] + ar_mask = ( + torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0) + .long() + .expand_as(result) + ) + result *= 1 - ar_mask + + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + + nb_total = ((prior_visits > 0) * ar_mask).sum() + + nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum() + + return nb_total, nb_correct + + test_nb_total, test_nb_correct = compute_nb_correct( + self.test_input[:1000], self.test_prior_visits[:1000] + ) + + logger( + f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" + ) + + logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}") + + +###################################################################### + + +import stack + + +class Stack(Task): + def __init__( + self, + nb_train_samples, + nb_test_samples, + batch_size, + logger, + nb_steps, + nb_stacks, + nb_digits, + fraction_values_for_train=None, + device=torch.device("cpu"), + ): + super().__init__() + + self.batch_size = 
batch_size + self.nb_steps = nb_steps + self.nb_stacks = nb_stacks + self.nb_digits = nb_digits + self.device = device + + if fraction_values_for_train is None: + values_for_train = None + values_for_test = None + else: + all = torch.randperm(10**nb_digits) + nb_for_train = int(all.size(0) * fraction_values_for_train) + values_for_train = all[:nb_for_train] + values_for_test = all[nb_for_train:] + + self.train_input, self.train_stack_counts = stack.generate_sequences( + nb_train_samples, + nb_steps, + nb_stacks, + nb_digits, + values_for_train, + self.device, + ) + + self.test_input, self.test_stack_counts = stack.generate_sequences( + nb_test_samples, + nb_steps, + nb_stacks, + nb_digits, + values_for_test, + self.device, + ) + + i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks) + counts = self.test_stack_counts.flatten()[i.flatten()] + counts = F.one_hot(counts).sum(0) + logger(f"test_pop_stack_counts {counts}") + + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + + def batches(self, split="train", nb_to_use=-1, desc=None): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + if nb_to_use > 0: + input = input[:nb_to_use] + if desc is None: + desc = f"epoch-{split}" + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=desc + ): + yield batch + + def vocabulary_size(self): + return self.nb_codes + + def produce_results( + self, n_epoch, model, result_dir, logger, deterministic_synthesis + ): + def compute_nb_correct(input): + result = input.clone() + stack.remove_popped_values(result, self.nb_stacks, self.nb_digits) + ar_mask = (result != input).long() + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + + errors = ((result != input).long() * ar_mask).reshape( + -1, 1 + self.nb_digits + ) + ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits) + + 
nb_total = ar_mask.max(1).values.sum() + nb_correct = nb_total - errors.max(1).values.sum() + + return nb_total, nb_correct + + test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000]) + + logger( + f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" + ) + + logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}") + + ############################################################## + # Log a few generated sequences + input = self.test_input[:10, : 12 * (1 + self.nb_digits)] + result = input.clone() + stack.remove_popped_values(result, self.nb_stacks, self.nb_digits) + ar_mask = (result != input).long() + + # for n in range(result.size(0)): + # logger( + # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}" + # ) + + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + + for n in range(result.size(0)): + logger( + f"test_after {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}" + ) + ############################################################## + + +###################################################################### + +import rpl + + +class RPL(Task): + def tensorize(self, sequences): + len_max = max([len(x) for x in sequences]) + return torch.cat( + [ + torch.tensor( + [ + [ + self.token2id[str(c)] + for c in s + [""] * (len_max - len(s)) + ] + for s in sequences + ] + ) + ], + 0, + ) + + def seq2str(self, seq): + return " ".join([self.id2token[i] for i in seq]) + + def __init__( + self, + nb_train_samples, + nb_test_samples, + batch_size, + nb_starting_values=3, + max_input=9, + prog_len=6, + nb_runs=5, + no_prog=False, + logger=None, + device=torch.device("cpu"), + ): + super().__init__() + + self.batch_size = batch_size + self.device = device + self.no_prog = no_prog + + train_sequences = [ + 
rpl.generate( + nb_starting_values=nb_starting_values, + nb_result_values_max=4 * nb_starting_values, + max_input=max_input, + prog_len=prog_len, + nb_runs=nb_runs, + ) + for _ in tqdm.tqdm(range(nb_train_samples), desc="train-data") + ] + + test_sequences = [ + rpl.generate( + nb_starting_values=nb_starting_values, + nb_result_values_max=4 * nb_starting_values, + max_input=max_input, + prog_len=prog_len, + nb_runs=nb_runs, + ) + for _ in tqdm.tqdm(range(nb_test_samples), desc="test-data") + ] + + symbols = list( + set([""] + [x for l in train_sequences + test_sequences for x in l]) + ) + val_max = max([x if type(x) is int else 0 for x in symbols]) + symbols = list(filter(lambda x: type(x) is str, symbols)) + symbols.sort() + symbols += [str(n) for n in range(val_max + 1)] + self.token2id = dict([(c, n) for n, c in enumerate(symbols)]) + self.id2token = dict([(n, c) for c, n in self.token2id.items()]) + + self.t_nul = self.token2id[""] + self.t_input = self.token2id[""] + self.t_output = self.token2id[""] + self.t_prog = self.token2id[""] + self.t_end = self.token2id[""] + + self.train_input = self.tensorize(train_sequences) + self.test_input = self.tensorize(test_sequences) + + if no_prog: + # Excise the program from every train and test example + k = torch.arange(self.train_input.size(1), device=self.train_input.device)[ + None, : + ] + p = ( + ((self.train_input == self.t_prog).long() * k) + .max(1, keepdim=True) + .values + ) + self.train_input = ( + self.train_input * (k <= p).long() + + self.t_end * (k == p + 1).long() + + self.t_nul * (k > p + 1).long() + ) + k = torch.arange(self.test_input.size(1), device=self.test_input.device)[ + None, : + ] + p = ( + ((self.test_input == self.t_prog).long() * k) + .max(1, keepdim=True) + .values + ) + self.test_input = ( + self.test_input * (k <= p).long() + + self.t_end * (k == p + 1).long() + + self.t_nul * (k > p + 1).long() + ) + + if logger is not None: + logger(f"value_max {val_max}") + for x in 
self.train_input[:25]: + end = (x != self.t_nul).nonzero().max().item() + 1 + seq = [self.id2token[i.item()] for i in x[:end]] + s = " ".join(seq) + logger(f"example_seq {s}") + + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + + def batches(self, split="train", nb_to_use=-1, desc=None): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + if nb_to_use > 0: + input = input[:nb_to_use] + if desc is None: + desc = f"epoch-{split}" + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=desc + ): + last = (batch != self.t_nul).max(0).values.nonzero().max() + 3 + batch = batch[:, :last].to(self.device) + yield batch + + def vocabulary_size(self): + return self.nb_codes + + def produce_results( + self, n_epoch, model, result_dir, logger, deterministic_synthesis + ): + # -------------------------------------------------------------------- + def compute_nb_errors_prog(input, nb_to_log=0): + result = input.clone() + s = (result == self.t_prog).long() + ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1) + result = (1 - ar_mask) * result + ar_mask * self.t_nul + + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + + sum_nb_total, sum_nb_errors = 0, 0 + for one_input, one_result in zip(input, result): + seq = [self.id2token[i.item()] for i in one_result] + nb_total, nb_errors, prog, stacks = rpl.compute_nb_errors(seq) + sum_nb_total += 1 + sum_nb_errors += 0 if nb_errors == 0 else 1 + if nb_to_log > 0: + gt_seq = [self.id2token[i.item()] for i in one_input] + _, _, gt_prog, _ = rpl.compute_nb_errors(gt_seq) + gt_prog = " ".join([str(x) for x in gt_prog]) + prog = " ".join([str(x) for x in prog]) + comment = "*" if nb_errors == 0 else "-" + logger(f"{comment} PROG [{gt_prog}] PREDICTED [{prog}]") + for start_stack, target_stack, result_stack, correct in stacks: + comment = "*" if correct 
else "-" + start_stack = " ".join([str(x) for x in start_stack]) + target_stack = " ".join([str(x) for x in target_stack]) + result_stack = " ".join([str(x) for x in result_stack]) + logger( + f" {comment} [{start_stack}] -> [{target_stack}] PREDICTED [{result_stack}]" + ) + nb_to_log -= 1 + + return sum_nb_total, sum_nb_errors + + # -------------------------------------------------------------------- + def compute_nb_errors_output(input, nb_to_log=0): + result = input.clone() + k = torch.arange(result.size(1), device=result.device)[None, :] + last_output_idx = ( + ((result == self.t_output) * k).max(dim=1, keepdim=True).values + ) + first_prog_idx = ( + ((result == self.t_prog) * k).max(dim=1, keepdim=True).values + ) + ar_mask = (k > last_output_idx).long() * (k < first_prog_idx).long() + result = (1 - ar_mask) * result + ar_mask * self.t_nul + + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + + sum_nb_total, sum_nb_errors = 0, 0 + for one_input, one_result, i, j in zip( + input, result, last_output_idx, first_prog_idx + ): + seq = [self.id2token[i.item()] for i in one_result] + sum_nb_total += 1 + correct = (one_input - one_result).abs().max() == 0 + sum_nb_errors += 0 if correct else 1 + if nb_to_log > 0: + result_stack = [ + self.id2token[i.item()] for i in one_result[i : j + 1] + ] + target_stack = [ + self.id2token[i.item()] for i in one_input[i : j + 1] + ] + comment = "*" if correct else "-" + result_stack = " ".join([str(x) for x in result_stack]) + target_stack = " ".join([str(x) for x in target_stack]) + logger( + f"output_test {comment} [{target_stack}] PREDICTED [{result_stack}]" + ) + nb_to_log -= 1 + + return sum_nb_total, sum_nb_errors + + # -------------------------------------------------------------------- + + if not self.no_prog: + test_nb_total, test_nb_errors = compute_nb_errors_prog( + self.test_input[:1000].to(self.device), nb_to_log=10 + ) + + logger( + 
f"accuracy_prog_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%" + ) + + logger(f"main_test_accuracy {n_epoch} {1-test_nb_errors/test_nb_total}") + + test_nb_total, test_nb_errors = compute_nb_errors_output( + self.test_input[:1000].to(self.device), nb_to_log=10 + ) + + logger( + f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%" + ) + + if save_attention_image is not None: + ns = torch.randint(self.test_input.size(0), (1,)).item() + input = self.test_input[ns : ns + 1].clone() + last = (input != self.t_nul).max(0).values.nonzero().max() + 3 + input = input[:, :last].to(self.device) + + with torch.autograd.no_grad(): + t = model.training + model.eval() + model.record_attention(True) + model(BracketedSequence(input)) + model.train(t) + ram = model.retrieve_attention() + model.record_attention(False) + + tokens_output = [self.id2token[i.item()] for i in input[0]] + tokens_input = ["n/a"] + tokens_output[:-1] + for n_head in range(ram[0].size(1)): + filename = os.path.join( + result_dir, f"rpl_attention_{n_epoch}_h{n_head}.pdf" + ) + attention_matrices = [m[0, n_head] for m in ram] + save_attention_image( + filename, + tokens_input, + tokens_output, + attention_matrices, + k_top=10, + # min_total_attention=0.9, + token_gap=12, + layer_gap=50, + ) + logger(f"wrote {filename}") + + +###################################################################### + + +import expr + + +class Expr(Task): + def tensorize(self, sequences): + len_max = max([len(x) for x in sequences]) + return torch.cat( + [ + torch.tensor( + [ + [self.char2id[c] for c in s + "#" * (len_max - len(s))] + for s in sequences + ] + ) + ], + 0, + ).to(self.device) + + def __init__( + self, + nb_train_samples, + nb_test_samples, + nb_variables, + sequence_length, + operand_max, + result_max, + batch_size, + device=torch.device("cpu"), + ): + 
super().__init__() + + self.batch_size = batch_size + self.device = device + + train_sequences = expr.generate_sequences( + nb_train_samples, + nb_variables=nb_variables, + length=sequence_length, + operand_max=operand_max, + result_max=result_max, + ) + + test_sequences = expr.generate_sequences( + nb_test_samples, + nb_variables=nb_variables, + length=sequence_length, + operand_max=operand_max, + result_max=result_max, + ) + + symbols = list(set("#" + "".join(train_sequences + test_sequences))) + symbols.sort() + + self.char2id = dict([(c, n) for n, c in enumerate(symbols)]) + self.id2char = dict([(n, c) for c, n in self.char2id.items()]) + + self.filler, self.space = self.char2id["#"], self.char2id[" "] + + self.train_input = self.tensorize(train_sequences) + self.test_input = self.tensorize(test_sequences) + + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + + def batches(self, split="train", nb_to_use=-1, desc=None): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + if nb_to_use > 0: + input = input[:nb_to_use] + if desc is None: + desc = f"epoch-{split}" + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=desc + ): + last = (batch != self.filler).max(0).values.nonzero().max() + 3 + batch = batch[:, :last] + yield batch + + def vocabulary_size(self): + return self.nb_codes + + def seq2str(self, s): + return "".join([self.id2char[k.item()] for k in s]) + + def produce_results( + self, + n_epoch, + model, + result_dir, + logger, + deterministic_synthesis, + input_file=None, + ): + def compute_nb_correct(input): + result = input.clone() + s = (result == self.space).long() + ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1) + result = (1 - ar_mask) * result + ar_mask * self.filler + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + + nb_total = input.size(0) + 
nb_correct = (input == result).long().min(1).values.sum() + + ####################################################################### + # Comput predicted vs. true variable values + + nb_delta = torch.zeros(5, dtype=torch.int64) + nb_missed = 0 + + values_input = expr.extract_results([self.seq2str(s) for s in input]) + values_result = expr.extract_results([self.seq2str(s) for s in result]) + + filename = os.path.join(result_dir, f"expr_result_{n_epoch:04d}.txt") + + with open(filename, "w") as f: + for i, r in zip(values_input, values_result): + for n, vi in i.items(): + vr = r.get(n) + f.write(f"{vi} {-1 if vr is None else vr}\n") + + if vr is None or vr < 0: + nb_missed += 1 + else: + d = abs(vr - vi) + if d >= nb_delta.size(0): + nb_missed += 1 + else: + nb_delta[d] += 1 + + ###################################################################### + + return nb_total, nb_correct, nb_delta, nb_missed + + ( + test_nb_total, + test_nb_correct, + test_nb_delta, + test_nb_missed, + ) = compute_nb_correct(self.test_input[:10000]) + + logger( + f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" + ) + + logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}") + + nb_total = test_nb_delta.sum() + test_nb_missed + for d in range(test_nb_delta.size(0)): + logger( + f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%" + ) + logger( + f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%" + ) + + ############################################################## + # Log a few generated sequences + if input_file is None: + input = self.test_input[:10] + else: + with open(input_file, "r") as f: + sequences = [e.strip() for e in f.readlines()] + sequences = [s + " " + "#" * 50 for s in sequences] + input = self.tensorize(sequences) + + result = input.clone() + s = (result == self.space).long() + ar_mask = 
(s.cumsum(dim=1) - s).clamp(min=0, max=1) + result = (1 - ar_mask) * result + ar_mask * self.filler + + for n in range(result.size(0)): + logger(f"test_before {self.seq2str(result[n])}") + + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + + correct = (1 - ar_mask) * self.space + ar_mask * input + for n in range(result.size(0)): + comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else "" + logger(f"test_after {self.seq2str(result[n])} {comment}") + logger(f"truth {self.seq2str(correct[n])}") + ############################################################## + + +###################################################################### + +import grid + + +class Grid(Task): + # Make a tensor from a list of strings + def str2tensor(self, descr): + token_descr = [s.strip().split(" ") for s in descr] + l = max([len(s) for s in token_descr]) + token_descr = [s + ["#"] * (l - len(s)) for s in token_descr] + id_descr = [[self.token2id[u] for u in s] for s in token_descr] + return torch.tensor(id_descr, device=self.device) + + # Make a list of strings from a tensor + def tensor2str(self, x): + return [" ".join([self.id2token[t.item()] for t in r]) for r in x] + + # trim all the tensors in the tuple z to remove as much token from + # left and right in the first tensor. 
If z is a tuple, all its + # elements are trimed according to the triming for the first + def trim(self, z, token="#"): + n = self.token2id[token] + if type(z) == tuple: + x = z[0] + i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0) + a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min() + return tuple([t[:, a:b] for t in z]) + else: + i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0) + a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min() + return z[:, a:b] + + ###################### + + def __init__( + self, + nb_train_samples, + nb_test_samples, + batch_size, + size, + logger=None, + device=torch.device("cpu"), + ): + super().__init__() + + self.device = device + self.batch_size = batch_size + self.grid_factory = grid.GridFactory(size=size) + + if logger is not None: + logger( + f"generating {nb_train_samples+nb_test_samples} samples (can take some time)" + ) + + self.train_descr = self.grid_factory.generate_samples( + nb_train_samples, lambda r: tqdm.tqdm(r) + ) + self.test_descr = self.grid_factory.generate_samples( + nb_test_samples, lambda r: tqdm.tqdm(r) + ) + + # Build the tokenizer + tokens = set() + for d in [self.train_descr, self.test_descr]: + for s in d: + for t in s.strip().split(" "): + tokens.add(t) + # make this set a sorted list to get the same tensors given + # the same descr + tokens = list(tokens) + tokens.sort() + tokens = ["#"] + tokens + self.token2id = dict([(t, n) for n, t in enumerate(tokens)]) + self.id2token = dict([(n, t) for n, t in enumerate(tokens)]) + self.t_nul = self.token2id["#"] + self.t_true = self.token2id["true"] + self.t_false = self.token2id["false"] + + # Tokenize the train and test sets + self.train_input = self.str2tensor(self.train_descr) + self.test_input = self.str2tensor(self.test_descr) + + def batches(self, split="train"): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + for batch in 
tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}" + ): + yield self.trim(batch) + + def vocabulary_size(self): + return len(self.token2id) + + def produce_results( + self, n_epoch, model, result_dir, logger, deterministic_synthesis + ): + correct = self.test_input[:1000] + result = correct.clone() + ar_mask = torch.logical_or(result == self.t_true, result == self.t_false).long() + result *= 1 - ar_mask # paraaaaanoiaaaaaaa + + logger(f"----------------------------------------------------------") + + for e in self.tensor2str(result[:10]): + logger(f"test_before {e}") + + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + + logger(f"----------------------------------------------------------") + + for e in self.tensor2str(result[:10]): + logger(f"test_after {e}") + + logger(f"----------------------------------------------------------") + + nb_total = ar_mask.sum().item() + nb_correct = ((correct == result).long() * ar_mask).sum().item() + + logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}") + logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}") + + +###################################################################### + +import qmlp + + +class QMLP(Task): + ###################### + + def __init__( + self, + nb_train_samples, + nb_test_samples, + batch_size, + result_dir, + logger=None, + device=torch.device("cpu"), + ): + super().__init__() + + self.device = device + self.batch_size = batch_size + self.nb_samples_per_mlp = 256 + + if logger is not None: + logger( + f"generating {nb_train_samples+nb_test_samples} samples (can take some time)" + ) + + seq, q_test_set, test_error = qmlp.generate_sequence_and_test_set( + nb_mlps=nb_train_samples + nb_test_samples, + nb_samples=self.nb_samples_per_mlp, + device=self.device, + batch_size=64, + nb_epochs=250, + nb_mlps_per_batch=1024, + ) + + self.train_input = seq[:nb_train_samples] 
+ self.train_q_test_set = q_test_set[:nb_train_samples] + self.train_ref_test_errors = test_error[:nb_train_samples] + self.test_input = seq[nb_train_samples:] + self.test_q_test_set = q_test_set[nb_train_samples:] + self.test_ref_test_errors = test_error[nb_train_samples:] + + filename = os.path.join(result_dir, f"train_errors_ref.dat") + with open(filename, "w") as f: + for e in self.train_ref_test_errors: + f.write(f"{e}\n") + + filename = os.path.join(result_dir, f"test_errors_ref.dat") + with open(filename, "w") as f: + for e in self.test_ref_test_errors: + f.write(f"{e}\n") + + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + + def batches(self, split="train"): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}" + ): + yield batch + + def vocabulary_size(self): + return self.nb_codes + + def produce_results( + self, n_epoch, model, result_dir, logger, deterministic_synthesis + ): + correct = self.test_input[:1000] + result = correct.clone() + ar_mask = ( + torch.arange(result.size(1), device=result.device) + > self.nb_samples_per_mlp * 3 + 1 + ).long()[None, :] + ar_mask = ar_mask.expand_as(result) + result *= 1 - ar_mask # paraaaaanoiaaaaaaa + + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + + q_train_set = result[:, : self.nb_samples_per_mlp * 3] + q_params = result[:, self.nb_samples_per_mlp * 3 + 1 :] + error_test = qmlp.evaluate_q_params(q_params, self.test_q_test_set) + + filename = os.path.join(result_dir, f"test_errors_{n_epoch:04d}.dat") + with open(filename, "w") as f: + for e in error_test: + f.write(f"{e}\n") + + +###################################################################### diff --git a/world.py b/world.py new file mode 100755 index 0000000..aad0bfb --- /dev/null +++ b/world.py 
@@ -0,0 +1,566 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret
+
+import math, sys, tqdm
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+import cairo
+
+######################################################################
+
+
+class Box:
+    # Number of quantization levels per color channel: colors are
+    # integers in [0, nb_rgb_levels - 1]
+    nb_rgb_levels = 10
+
+    def __init__(self, x, y, w, h, r, g, b):
+        # (x, y) is the top-left corner and (w, h) the size, in the unit
+        # square (scene2tensor scales them by the image size); (r, g, b)
+        # are integer color levels
+        self.x = x
+        self.y = y
+        self.w = w
+        self.h = h
+        self.r = r
+        self.g = g
+        self.b = b
+
+    def collision(self, scene):
+        # True if this box overlaps, or even touches, another box of the
+        # scene (the inequalities are not strict)
+        for c in scene:
+            if (
+                self is not c
+                and max(self.x, c.x) <= min(self.x + self.w, c.x + c.w)
+                and max(self.y, c.y) <= min(self.y + self.h, c.y + c.h)
+            ):
+                return True
+        return False
+
+
+######################################################################
+
+
+class Normalizer(nn.Module):
+    # Standardizes the input with a fixed mean and standard deviation,
+    # stored as buffers so that they move with the module but are not trained
+    def __init__(self, mu, std):
+        super().__init__()
+        self.register_buffer("mu", mu)
+        self.register_buffer("log_var", 2 * torch.log(std))
+
+    def forward(self, x):
+        return (x - self.mu) / torch.exp(self.log_var / 2.0)
+
+
+class SignSTE(nn.Module):
+    # Binarizes to {-1, +1}. In training mode the gradient is the one of
+    # tanh (straight-through estimator), the forward value staying s up to
+    # rounding
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        # torch.sign() takes three values (it maps 0 to 0)
+        s = (x >= 0).float() * 2 - 1
+
+        if self.training:
+            u = torch.tanh(x)
+            return s + u - u.detach()
+        else:
+            return s
+
+
+class DiscreteSampler2d(nn.Module):
+    # Indicator of the maximal values along dim -3 (one-hot unless there
+    # are ties). In training mode the gradient is the one of the softmax
+    # (straight-through estimator)
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        s = (x >= x.max(-3, keepdim=True).values).float()
+
+        if self.training:
+            u = x.softmax(dim=-3)
+            return s + u - u.detach()
+        else:
+            return s
+
+
+# Returns h_threshold minus the mean binary entropy (in bits, capped at
+# h_threshold) of the bit probabilities averaged over the batch (dim 0).
+# It is zero once every bit has an entropy of at least h_threshold, so it
+# pushes the bits away from being constant across samples.
+def loss_H(binary_logits, h_threshold=1):
+    p = binary_logits.sigmoid().mean(0)
+    h = (-p.xlogy(p) - (1 - p).xlogy(1 - p)) / math.log(2)
+    h.clamp_(max=h_threshold)
+    return h_threshold - h.mean()
+
+
+def train_encoder(
+    train_input,
+    test_input,
+    depth,
+    nb_bits_per_token,
+    dim_hidden=48,
+    lambda_entropy=0.0,
+    lr_start=1e-3,
+    lr_end=1e-4,
+    nb_epochs=10,
+    batch_size=25,
+    logger=None,
+    device=torch.device("cpu"),
+):
+    # Trains a convolutional auto-encoder with a binary bottleneck.
+    #
+    # train_input and test_input are long tensors of color levels in
+    # [0, Box.nb_rgb_levels - 1], presumably (N, 3, H, W) as built by
+    # scene2tensor. The encoder downscales by 2**depth and outputs
+    # nb_bits_per_token channels binarized by SignSTE, and the decoder
+    # predicts the logits of every color level of every channel. The
+    # learning rate decays geometrically from lr_start to lr_end.
+    #
+    # Returns (encoder, quantizer, decoder).
+
+    if logger is None:
+        logger = lambda s: print(s)
+
+    mu, std = train_input.float().mean(), train_input.float().std()
+
+    def encoder_core(depth, dim):
+        l = [
+            [
+                nn.Conv2d(
+                    dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2
+                ),
+                nn.ReLU(),
+                nn.Conv2d(dim * 2**k, dim * 2 ** (k + 1), kernel_size=2, stride=2),
+                nn.ReLU(),
+            ]
+            for k in range(depth)
+        ]
+
+        return nn.Sequential(*[x for m in l for x in m])
+
+    def decoder_core(depth, dim):
+        l = [
+            [
+                nn.ConvTranspose2d(
+                    dim * 2 ** (k + 1), dim * 2**k, kernel_size=2, stride=2
+                ),
+                nn.ReLU(),
+                nn.ConvTranspose2d(
+                    dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2
+                ),
+                nn.ReLU(),
+            ]
+            for k in range(depth - 1, -1, -1)
+        ]
+
+        return nn.Sequential(*[x for m in l for x in m])
+
+    encoder = nn.Sequential(
+        Normalizer(mu, std),
+        nn.Conv2d(3, dim_hidden, kernel_size=1, stride=1),
+        nn.ReLU(),
+        # 64x64
+        encoder_core(depth=depth, dim=dim_hidden),
+        # 8x8
+        nn.Conv2d(dim_hidden * 2**depth, nb_bits_per_token, kernel_size=1, stride=1),
+    )
+
+    quantizer = SignSTE()
+
+    decoder = nn.Sequential(
+        nn.Conv2d(nb_bits_per_token, dim_hidden * 2**depth, kernel_size=1, stride=1),
+        # 8x8
+        decoder_core(depth=depth, dim=dim_hidden),
+        # 64x64
+        nn.ConvTranspose2d(dim_hidden, 3 * Box.nb_rgb_levels, kernel_size=1, stride=1),
+    )
+
+    model = nn.Sequential(encoder, decoder)
+
+    nb_parameters = sum(p.numel() for p in model.parameters())
+
+    logger(f"vqae nb_parameters {nb_parameters}")
+
+    model.to(device)
+
+    for k in range(nb_epochs):
+        # Geometric interpolation from lr_start to lr_end. The max() avoids
+        # a division by zero when nb_epochs == 1
+        lr = math.exp(
+            math.log(lr_start)
+            + math.log(lr_end / lr_start) / max(nb_epochs - 1, 1) * k
+        )
+        # NOTE: a fresh optimizer per epoch also resets Adam's moment estimates
+        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+
+        acc_train_loss = 0.0
+
+        for input in tqdm.tqdm(train_input.split(batch_size), desc="vqae-train"):
+            input = input.to(device)
+            z = encoder(input)
+            zq = quantizer(z)
+            output = decoder(zq)
+
+            # (N, 3 * nb_rgb_levels, H, W) -> (N, nb_rgb_levels, 3, H, W) so
+            # that dim 1 holds the class logits expected by cross_entropy
+            output = output.reshape(
+                output.size(0), -1, 3, output.size(2), output.size(3)
+            )
+
+            train_loss = F.cross_entropy(output, input)
+
+            if lambda_entropy > 0:
+                train_loss = train_loss + lambda_entropy * loss_H(z, h_threshold=0.5)
+
+            acc_train_loss += train_loss.item() * input.size(0)
+
+            optimizer.zero_grad()
+            train_loss.backward()
+            optimizer.step()
+
+        acc_test_loss = 0.0
+
+        # Evaluation only, no need to build the autograd graph
+        with torch.no_grad():
+            for input in tqdm.tqdm(test_input.split(batch_size), desc="vqae-test"):
+                input = input.to(device)
+                z = encoder(input)
+                zq = quantizer(z)
+                output = decoder(zq)
+
+                output = output.reshape(
+                    output.size(0), -1, 3, output.size(2), output.size(3)
+                )
+
+                test_loss = F.cross_entropy(output, input)
+
+                acc_test_loss += test_loss.item() * input.size(0)
+
+        train_loss = acc_train_loss / train_input.size(0)
+        test_loss = acc_test_loss / test_input.size(0)
+
+        logger(f"vqae train {k} lr {lr} train_loss {train_loss} test_loss {test_loss}")
+        sys.stdout.flush()
+
+    return encoder, quantizer, decoder
+
+
+######################################################################
+
+
+def scene2tensor(xh, yh, scene, size):
+    # Renders the boxes of scene and a black square "hand" centered at
+    # (xh, yh), in unit-square coordinates, into a (1, 3, size, size) long
+    # tensor of color levels in [0, Box.nb_rgb_levels - 1]
+    width, height = size, size
+    # cairo's ARGB32 surface is row-major: height rows of width pixels
+    pixel_map = torch.ByteTensor(height, width, 4).fill_(255)
+    data = pixel_map.numpy()
+    surface = cairo.ImageSurface.create_for_data(
+        data, cairo.FORMAT_ARGB32, width, height
+    )
+
+    ctx = cairo.Context(surface)
+    ctx.set_fill_rule(cairo.FILL_RULE_EVEN_ODD)
+
+    for b in scene:
+        ctx.move_to(b.x * size, b.y * size)
+        ctx.rel_line_to(b.w * size, 0)
+        ctx.rel_line_to(0, b.h * size)
+        ctx.rel_line_to(-b.w * size, 0)
+        ctx.close_path()
+        ctx.set_source_rgba(
+            b.r / (Box.nb_rgb_levels - 1),
+            b.g / (Box.nb_rgb_levels - 1),
+            b.b / (Box.nb_rgb_levels - 1),
+            1.0,
+        )
+        ctx.fill()
+
+    hs = size * 0.1
+    ctx.set_source_rgba(0.0, 0.0, 0.0, 1.0)
+    ctx.move_to(xh * size - hs / 2, yh * size - hs / 2)
+    ctx.rel_line_to(hs, 0)
+    ctx.rel_line_to(0, hs)
+    ctx.rel_line_to(-hs, 0)
+    ctx.close_path()
+    ctx.fill()
+
+    # ARGB32 pixels are native-endian 32-bit words, i.e. B, G, R, A bytes
+    # on little-endian machines, hence the flip to get R, G, B
+    return (
+        pixel_map[None, :, :, :3]
+        .flip(-1)
+        .permute(0, 3, 1, 2)
+        .long()
+        .mul(Box.nb_rgb_levels)
+        .floor_divide(256)
+    )
+
+
+def random_scene(nb_insert_attempts=3):
+    # Makes nb_insert_attempts attempts at adding a box of random size,
+    # position and color, each box being kept only if it does not collide
+    # with the ones already in the scene
+    scene = []
+    colors = [
+        ((Box.nb_rgb_levels - 1), 0, 0),
+        (0, (Box.nb_rgb_levels - 1), 0),
+        (0, 0, (Box.nb_rgb_levels - 1)),
+        ((Box.nb_rgb_levels - 1), (Box.nb_rgb_levels - 1), 0),
+        (
+            (Box.nb_rgb_levels * 2) // 3,
+            (Box.nb_rgb_levels * 2) // 3,
+            (Box.nb_rgb_levels * 2) // 3,
+        ),
+    ]
+
+    for k in range(nb_insert_attempts):
+        wh = torch.rand(2) * 0.2 + 0.2
+        xy = torch.rand(2) * (1 - wh)
+        c = colors[torch.randint(len(colors), (1,))]
+        b = Box(
+            xy[0].item(), xy[1].item(), wh[0].item(), wh[1].item(), c[0], c[1], c[2]
+        )
+        if not b.collision(scene):
+            scene.append(b)
+
+    return scene
+
+
+def generate_episode(steps, size=64):
+    # Simulates a random episode in a random scene. A non-grasp action moves
+    # the hand by delta along one axis (or not at all); a grasp action moves
+    # the box under the hand, and the hand with it, unless the box would
+    # leave the unit square or collide. steps is a list of booleans telling,
+    # per step, whether to record the frame before that step's action.
+    # Episodes are re-drawn until more than len(steps) // 3 grasps moved a box.
+    #
+    # Returns (frames, actions): a list of (1, 3, size, size) tensors and a
+    # long tensor with one index into effects per step.
+    delta = 0.1
+    # (grasp, dx, dy)
+    effects = [
+        (False, 0, 0),
+        (False, delta, 0),
+        (False, 0, delta),
+        (False, -delta, 0),
+        (False, 0, -delta),
+        (True, delta, 0),
+        (True, 0, delta),
+        (True, -delta, 0),
+        (True, 0, -delta),
+    ]
+
+    while True:
+        frames = []
+
+        scene = random_scene()
+        xh, yh = tuple(x.item() for x in torch.rand(2))
+
+        actions = torch.randint(len(effects), (len(steps),))
+        nb_changes = 0
+
+        for s, a in zip(steps, actions):
+            if s:
+                frames.append(scene2tensor(xh, yh, scene, size=size))
+
+            grasp, dx, dy = effects[a]
+
+            if grasp:
+                for b in scene:
+                    if b.x <= xh and b.x + b.w >= xh and b.y <= yh and b.y + b.h >= yh:
+                        x, y = b.x, b.y
+                        b.x += dx
+                        b.y += dy
+                        # Undo the move if the box leaves the unit square or
+                        # hits another box, otherwise the hand follows it
+                        if (
+                            b.x < 0
+                            or b.y < 0
+                            or b.x + b.w > 1
+                            or b.y + b.h > 1
+                            or b.collision(scene)
+                        ):
+                            b.x, b.y = x, y
+                        else:
+                            xh += dx
+                            yh += dy
+                            nb_changes += 1
+            else:
+                x, y = xh, yh
+                xh += dx
+                yh += dy
+                # The hand cannot leave the unit square
+                if xh < 0 or xh > 1 or yh < 0 or yh > 1:
+                    xh, yh = x, y
+
+        if nb_changes > len(steps) // 3:
+            break
+
+    return frames, actions
+
+
+######################################################################
+
+
+def generate_episodes(nb, steps):
+    # Returns the frames of nb episodes concatenated along dim 0, and their
+    # actions stacked into a (nb, len(steps)) tensor
+    all_frames, all_actions = [], []
+    for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world-data"):
+        frames, actions = generate_episode(steps)
+        all_frames += frames
+        all_actions += [actions[None, :]]
+    return torch.cat(all_frames, 0).contiguous(), torch.cat(all_actions, 0)
+
+
+def create_data_and_processors(
+    nb_train_samples,
+    nb_test_samples,
+    mode,
+    nb_steps,
+    depth=3,
+    nb_bits_per_token=8,
+    nb_epochs=10,
+    device=torch.device("cpu"),
+    device_storage=torch.device("cpu"),
+    logger=None,
+):
+    # Generates train and test episodes, trains the binary auto-encoder on
+    # their frames, and returns the data with two converters: frame2seq maps
+    # frames to sequences of integer tokens (one token per position of the
+    # encoder's output map, whose bits are the quantized channels), and
+    # seq2frame maps such sequences back to frames by sampling the decoder's
+    # output at temperature T.
+    #
+    # In "first_last" mode an episode has nb_steps + 3 steps, and only the
+    # frames before the first and before the last actions are recorded.
+    assert mode in ["first_last"]
+
+    if mode == "first_last":
+        steps = [True] + [False] * (nb_steps + 1) + [True]
+
+    if logger is None:
+        logger = lambda s: print(s)
+
+    train_input, train_actions = generate_episodes(nb_train_samples, steps)
+    train_input, train_actions = train_input.to(device_storage), train_actions.to(
+        device_storage
+    )
+    test_input, test_actions = generate_episodes(nb_test_samples, steps)
+    test_input, test_actions = test_input.to(device_storage), test_actions.to(
+        device_storage
+    )
+
+    encoder, quantizer, decoder = train_encoder(
+        train_input,
+        test_input,
+        depth=depth,
+        nb_bits_per_token=nb_bits_per_token,
+        lambda_entropy=1.0,
+        nb_epochs=nb_epochs,
+        logger=logger,
+        device=device,
+    )
+    encoder.train(False)
+    quantizer.train(False)
+    decoder.train(False)
+
+    # Probe the encoder once to get the number of bits and the token map size
+    with torch.no_grad():
+        z = encoder(train_input[:1].to(device))
+    pow2 = (2 ** torch.arange(z.size(1), device=device))[None, None, :]
+    z_h, z_w = z.size(2), z.size(3)
+
+    logger(f"vqae input {train_input[0].size()} output {z[0].size()}")
+
+    @torch.no_grad()
+    def frame2seq(input, batch_size=25):
+        seq = []
+        p = pow2.to(device)
+        for x in input.split(batch_size):
+            x = x.to(device)
+            z = encoder(x)
+            ze_bool = (quantizer(z) >= 0).long()
+            # (B, bits, z_h, z_w) -> (B, z_h * z_w, bits), each bit vector
+            # being read as a binary number with bit i of weight 2**i
+            output = (
+                ze_bool.permute(0, 2, 3, 1).reshape(
+                    ze_bool.size(0), -1, ze_bool.size(1)
+                )
+                * p
+            ).sum(-1)
+
+            seq.append(output)
+
+        return torch.cat(seq, dim=0)
+
+    @torch.no_grad()
+    def seq2frame(input, batch_size=25, T=1e-2):
+        frames = []
+        p = pow2.to(device)
+        for seq in input.split(batch_size):
+            seq = seq.to(device)
+            # Tokens back to bits, mapped to {-1, +1} like SignSTE's output
+            zd_bool = (seq[:, :, None] // p) % 2
+            zd_bool = zd_bool.reshape(zd_bool.size(0), z_h, z_w, -1).permute(0, 3, 1, 2)
+            logits = decoder(zd_bool * 2.0 - 1.0)
+            logits = logits.reshape(
+                logits.size(0), -1, 3, logits.size(2), logits.size(3)
+            ).permute(0, 2, 3, 4, 1)
+            output = torch.distributions.categorical.Categorical(
+                logits=logits / T
+            ).sample()
+
+            frames.append(output)
+
+        return torch.cat(frames, dim=0)
+
+    return train_input, train_actions, test_input, test_actions, frame2seq, seq2frame
+
+
+######################################################################
+
+if __name__ == "__main__":
+    (
+        train_input,
+        train_actions,
+        test_input,
+        test_actions,
+        frame2seq,
+        seq2frame,
+    ) = create_data_and_processors(
+        25000,
+        1000,
+        nb_epochs=5,
+        mode="first_last",
+        nb_steps=20,
+    )
+
+    input = test_input[:256]
+
+    seq = frame2seq(input)
+    output = seq2frame(seq)
+
+    torchvision.utils.save_image(
+        input.float() / (Box.nb_rgb_levels - 1), "orig.png", nrow=16
+    )
+
+    torchvision.utils.save_image(
+        output.float() / (Box.nb_rgb_levels - 1), "qtiz.png", nrow=16
+    )
-- 
2.39.5