From: François Fleuret Date: Sat, 6 Jan 2024 09:56:07 +0000 (+0100) Subject: Initial commit X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=4395f9a90218819997c706de9505cda1c86ad507;p=mygptrnn.git Initial commit --- 4395f9a90218819997c706de9505cda1c86ad507 diff --git a/expr.py b/expr.py new file mode 100755 index 0000000..685efd3 --- /dev/null +++ b/expr.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import math, re + +import torch, torchvision + +from torch import nn +from torch.nn import functional as F + + +def random_var(nb_variables=None, variables=None): + if variables is None: + return chr(ord("A") + torch.randint(nb_variables, (1,)).item()) + else: + l = list(variables) + return l[torch.randint(len(l), (1,)).item()] + + +def random_expr(variables, operand_max, budget): + if budget <= 5: + op = torch.randint(2, (1,)).item() + if op == 0 and len(variables) > 0: + return random_var(variables=variables) + else: + return str(torch.randint(operand_max + 1, (1,)).item()) + else: + op = torch.randint(3, (1,)).item() + if op == 0: + e = random_expr(variables, operand_max, budget - 2) + if ("+" in e or "-" in e or "*" in e) and (e[0] != "(" or e[-1] != ")"): + return "(" + e + ")" + else: + return e + else: + b = 2 + torch.randint(budget - 5, (1,)).item() + e1 = random_expr(variables, operand_max, b) + e2 = random_expr(variables, operand_max, budget - b - 1) + if op == 1: + return e1 + "+" + e2 + elif op == 2: + return e1 + "*" + e2 + + +def generate_program(nb_variables, operand_max, length): + s = "" + variables = set() + + while len(s) < length: + v = random_var(nb_variables=nb_variables) + s += v + "=" + random_expr(variables, operand_max, budget=20) + ";" + variables.add(v) + + return s, variables + + +def generate_sequences(nb, nb_variables=5, length=20, operand_max=9, result_max=99): + assert nb_variables <= 26 + sequences = [] + + for n in range(nb): + # We take length itself half of the time, and uniform between + # 1 and length otherwise. The actual length can be slightly + # greater + + l = min(length, 1 + torch.randint(length * 2, (1,)).item()) + result = None + while result == None or max(result.values()) > result_max: + p, v = generate_program(nb_variables, operand_max, l) + v = ", ".join(['"' + v + '": ' + v for v in v]) + ldict = {} + exec(p + "result={" + v + "}", globals(), ldict) + result = ldict["result"] + + k = list(result.keys()) + k.sort() + sequences.append(p + " " + "".join([v + ":" + str(result[v]) + ";" for v in k])) + + return sequences + + +def extract_results(seq): + f = lambda a: (a[0], -1 if a[1] == "" else int(a[1])) + results = [ + dict([f(tuple(x.split(":"))) for x in re.findall("[A-Z]:[0-9]*", s)]) + for s in seq + ] + return results + + +if __name__ == "__main__": + import time + + start_time = time.perf_counter() + sequences = generate_sequences(1000, length=40) + end_time = time.perf_counter() + for s in sequences[:10]: + print(s) + print(f"{len(sequences) / (end_time - start_time):.02f} samples per second") + + print(extract_results(sequences[:10])) diff --git a/ffutils.py b/ffutils.py new file mode 100755 index 0000000..23952e5 --- /dev/null +++ b/ffutils.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. 
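+######################################################################
+# A minimal usage sketch for expr.py above (hedged: the values shown
+# are hypothetical). Each generated sample is a program of assignments
+# followed by the variables' final values in alphabetical order, e.g.
+# "B=4;A=B+7; A:11;B:4;", and extract_results parses the value section
+# back into dictionaries:
+#
+#   from expr import generate_sequences, extract_results
+#
+#   seqs = generate_sequences(2, nb_variables=3, length=20)
+#   print(seqs[0])                # e.g. C=7;A=C*2; A:14;C:7;
+#   print(extract_results(seqs))  # e.g. [{'A': 14, 'C': 7}, ...]
+######################################################################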
+# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import torch +import sys, contextlib + +import torch +from torch import Tensor + +###################################################################### + + +@contextlib.contextmanager +def evaluation(*models): + with torch.inference_mode(): + t = [(m, m.training) for m in models] + for m in models: + m.train(False) + yield + for m, u in t: + m.train(u) + + +###################################################################### + +from torch.utils._python_dispatch import TorchDispatchMode + + +def hasNaN(x): + if torch.is_tensor(x): + return x.numel() > 0 and x.isnan().max() + else: + try: + return any([hasNaN(y) for y in x]) + except TypeError: + return False + + +class NaNDetect(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args, kwargs=None): + kwargs = kwargs or {} + res = func(*args, **kwargs) + + if hasNaN(res): + raise RuntimeError( + f"Function {func}(*{args}, **{kwargs}) " "returned a NaN" + ) + return res + + +###################################################################### + + +def exception_hook(exc_type, exc_value, tb): + r"""Hacks the call stack message to show all the local variables + in case of relevant error, and prints tensors as shape, dtype and + device. + + """ + + repr_orig = Tensor.__repr__ + Tensor.__repr__ = lambda x: f"{x.size()}:{x.dtype}:{x.device}" + + while tb: + print("--------------------------------------------------\n") + filename = tb.tb_frame.f_code.co_filename + name = tb.tb_frame.f_code.co_name + line_no = tb.tb_lineno + print(f' File "{filename}", line {line_no}, in {name}') + print(open(filename, "r").readlines()[line_no - 1]) + + if exc_type in {RuntimeError, ValueError, IndexError, TypeError}: + for n, v in tb.tb_frame.f_locals.items(): + print(f" {n} -> {v}") + + print() + tb = tb.tb_next + + Tensor.__repr__ = repr_orig + + print(f"{exc_type.__name__}: {exc_value}") + + +def activate_tensorstack(): + sys.excepthook = exception_hook + + +###################################################################### + +if __name__ == "__main__": + import torch + + def dummy(a, b): + print(a @ b) + + def blah(a, b): + c = b + b + dummy(a, c) + + mmm = torch.randn(2, 3) + xxx = torch.randn(3) + # print(xxx@mmm) + blah(mmm, xxx) + blah(xxx, mmm) diff --git a/graph.py b/graph.py new file mode 100755 index 0000000..07e376a --- /dev/null +++ b/graph.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python + +import math + +import torch, torchvision + +from torch import nn +from torch.nn import functional as F + +import cairo + + +###################################################################### + + +def save_attention_image( + # image to save + filename, + tokens_input, + tokens_output, + # list of 2d tensors T2xT1, T3xT2, ..., TkxTk-1 + attention_matrices, + # do not draw links with a lesser attention + min_link_attention=0, + # draw only the strongest links necessary so that their summed + # attention is above min_total_attention + min_total_attention=None, + # draw only the top k links + k_top=None, + # the purely graphical settings + curved=True, + pixel_scale=8, + token_gap=15, + layer_gap=25, + y_eps=0.5, + padding=10, +): + if k_top is not None: + am = [] + for m in attention_matrices: + am.append(m * (m.sort(dim=-1, descending=True).indices < k_top)) + attention_matrices = am + + if min_total_attention is not None: + am = [] + for m in attention_matrices: + s = m.sort(dim=-1) + m = 1 - (s.values.cumsum(-1) < 1 - min_total_attention).long() + b = 
m.new(m.size()).scatter_(dim=-1, index=s.indices, src=m) + am.append(m * b) + + surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None) + + ctx = cairo.Context(surface) + ctx.scale(pixel_scale, pixel_scale) + + ctx.set_source_rgb(0.0, 0.0, 0.0) + ctx.set_font_size(4.0) + # ctx.select_font_face("Arial", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL) + + x, y = 0, 0 + + ctx.set_line_width(0.25) + for d in range(len(attention_matrices)): + at = attention_matrices[d].to("cpu") + ni = torch.arange(at.size(0))[:, None].expand_as(at) + nj = torch.arange(at.size(1))[None, :].expand_as(at) + at = at.flatten() + o = at.sort().indices + at = at[o] + ni = ni.flatten()[o] + nj = nj.flatten()[o] + for i, j, a in zip(ni, nj, at): + if a > 0 and a >= min_link_attention: + c = 1 - a.item() + ctx.set_source_rgb(c, c, c) + ax, ay = j * token_gap, y - y_eps + ctx.move_to(ax, ay) + dx, dy = i * token_gap, y - layer_gap + y_eps + if curved: + bx, by = ax, ay - layer_gap * 0.5 + cx, cy = dx, dy + layer_gap * 0.5 + ctx.curve_to(bx, by, cx, cy, dx, dy) + else: + ctx.line_to(dx, dy) + ctx.stroke() + y -= layer_gap + + for d in range(0, len(attention_matrices) + 1): + n = ( + attention_matrices[0].size(-1) + if d == 0 + else attention_matrices[d - 1].size(-2) + ) + for n in range(n): + xc, yc = n * token_gap, -d * layer_gap + ctx.set_source_rgb(1.0, 1.0, 1.0) + ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi) + ctx.fill() + ctx.set_source_rgb(0.0, 0.0, 0.0) + ctx.arc(xc, yc, token_gap / 20, 0, 2 * math.pi) + ctx.fill() + + ctx.set_source_rgb(0.0, 0.0, 0.0) + + for k, t in enumerate(tokens_input): + s = str(t) + ( + x_bearing, + y_bearing, + width_t, + height_t, + x_advance, + y_advance, + ) = ctx.text_extents(s) + ctx.move_to(k * token_gap - width_t / 2, 2 * token_gap / 5) + ctx.show_text(s) + + for k, t in enumerate(tokens_output): + s = str(t) + ( + x_bearing, + y_bearing, + width_t, + height_t, + x_advance, + y_advance, + ) = ctx.text_extents(s) + ctx.move_to( + k * token_gap - width_t / 2, + -token_gap / 5 - len(attention_matrices) * layer_gap, + ) + ctx.show_text(s) + + x, y, width, height = surface.ink_extents() + x -= padding + y -= padding + width += 2 * padding + height += 2 * padding + pdf_surface = cairo.PDFSurface(filename, width, height) + ctx_pdf = cairo.Context(pdf_surface) + ctx_pdf.set_source_surface(surface, -x, -y) + ctx_pdf.paint() + pdf_surface.finish() + + +###################################################################### + +if __name__ == "__main__": + import mygpt + + tokens_output = ["", "-", 3, 4, ""] + tokens_input = [""] + tokens_output[:-1] + + vocabulary_size = 3 + x = torch.randint(vocabulary_size, (1, len(tokens_input))) + + model = mygpt.MyGPT( + vocabulary_size=vocabulary_size, + dim_model=4, + dim_keys=2, + dim_hidden=2, + nb_heads=2, + nb_blocks=5, + dropout=0.1, + causal=True, + ) + + model.eval() + model.record_attention() + + y1 = model(mygpt.BracketedSequence(x)).x + + attention_matrices = [m[0, 0] for m in model.retrieve_attention()] + + # attention_matrices = [torch.rand(*s) for s in [ (4,5),(3,4),(8,3),(5,8) ]] + + save_attention_image( + "attention.pdf", + tokens_input, + tokens_output, + attention_matrices, + # k_top=2, + min_total_attention=0.9, + ) diff --git a/grid.py b/grid.py new file mode 100755 index 0000000..268f4ee --- /dev/null +++ b/grid.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. 
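+######################################################################
+# In save_attention_image in graph.py above, the drawing can be
+# restricted to the k_top strongest links per row of each attention
+# matrix. The same intent, keeping the k largest entries of each row,
+# can also be written with torch.topk; a minimal sketch on a single 2d
+# matrix m (hypothetical, for illustration):
+#
+#   import torch
+#
+#   m = torch.rand(3, 5)
+#   k_top = 2
+#   mask = torch.zeros_like(m).scatter_(-1, m.topk(k_top, dim=-1).indices, 1.0)
+#   m = m * mask  # zero everywhere but the k_top largest per row
+######################################################################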
+# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import math +import torch, torchvision +import torch.nn.functional as F + +name_shapes = ["A", "B", "C", "D", "E", "F"] + +name_colors = ["red", "yellow", "blue", "green", "white", "purple"] + +###################################################################### + + +class GridFactory: + def __init__( + self, + size=6, + max_nb_items=4, + max_nb_transformations=3, + nb_questions=4, + ): + assert size % 2 == 0 + self.size = size + self.max_nb_items = max_nb_items + self.max_nb_transformations = max_nb_transformations + self.nb_questions = nb_questions + + def generate_scene(self): + nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2 + col = torch.full((self.size * self.size,), -1) + shp = torch.full((self.size * self.size,), -1) + a = torch.randperm(len(name_colors) * len(name_shapes))[:nb_items] + col[:nb_items] = a % len(name_colors) + shp[:nb_items] = a // len(name_colors) + i = torch.randperm(self.size * self.size) + col = col[i] + shp = shp[i] + return col.reshape(self.size, self.size), shp.reshape(self.size, self.size) + + def random_transformations(self, scene): + col, shp = scene + + descriptions = [] + nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item() + transformations = torch.randint(5, (nb_transformations,)) + + for t in transformations: + if t == 0: + col, shp = col.flip(0), shp.flip(0) + descriptions += [" vertical flip"] + elif t == 1: + col, shp = col.flip(1), shp.flip(1) + descriptions += [" horizontal flip"] + elif t == 2: + col, shp = col.flip(0).t(), shp.flip(0).t() + descriptions += [" rotate 90 degrees"] + elif t == 3: + col, shp = col.flip(0).flip(1), shp.flip(0).flip(1) + descriptions += [" rotate 180 degrees"] + elif t == 4: + col, shp = col.flip(1).t(), shp.flip(1).t() + descriptions += [" rotate 270 degrees"] + + col, shp = col.contiguous(), shp.contiguous() + + return (col, shp), descriptions + + def print_scene(self, scene): + col, shp = scene + + # for i in range(self.size): + # for j in range(self.size): + # if col[i,j] >= 0: + # print(f"at ({i},{j}) {name_colors[col[i,j]]} {name_shapes[shp[i,j]]}") + + for i in range(self.size): + for j in range(self.size): + if col[i, j] >= 0: + print(f"{name_colors[col[i,j]][0]}{name_shapes[shp[i,j]]}", end="") + elif j == 0: + print(" +", end="") + else: + print("-+", end="") + if j < self.size - 1: + print("--", end="") + else: + print("") + if i < self.size - 1: + for j in range(self.size - 1): + print(" | ", end="") + print(" |") + + def grid_positions(self, scene): + col, shp = scene + + properties = [] + + for i in range(self.size): + for j in range(self.size): + if col[i, j] >= 0: + n = f"{name_colors[col[i,j]]} {name_shapes[shp[i,j]]}" + properties += [f"a {n} at {i} {j}"] + + return properties + + def all_properties(self, scene): + col, shp = scene + + properties = [] + + for i1 in range(self.size): + for j1 in range(self.size): + if col[i1, j1] >= 0: + n1 = f"{name_colors[col[i1,j1]]} {name_shapes[shp[i1,j1]]}" + properties += [f"there is a {n1}"] + if i1 < self.size // 2: + properties += [f"a {n1} is in the top half"] + if i1 >= self.size // 2: + properties += [f"a {n1} is in the bottom half"] + if j1 < self.size // 2: + properties += [f"a {n1} is in the left half"] + if j1 >= self.size // 2: + properties += [f"a {n1} is in the right half"] + for i2 in range(self.size): + for j2 in range(self.size): + if col[i2, j2] >= 0: + n2 = f"{name_colors[col[i2,j2]]} 
{name_shapes[shp[i2,j2]]}" + if i1 > i2: + properties += [f"a {n1} is below a {n2}"] + if i1 < i2: + properties += [f"a {n1} is above a {n2}"] + if j1 > j2: + properties += [f"a {n1} is right of a {n2}"] + if j1 < j2: + properties += [f"a {n1} is left of a {n2}"] + if abs(i1 - i2) + abs(j1 - j2) == 1: + properties += [f"a {n1} is next to a {n2}"] + + return properties + + def generate_scene_and_questions(self): + while True: + while True: + start_scene = self.generate_scene() + scene, transformations = self.random_transformations(start_scene) + true = self.all_properties(scene) + if len(true) >= self.nb_questions: + break + + for a in range(10): + col, shp = scene + col, shp = col.view(-1), shp.view(-1) + p = torch.randperm(col.size(0)) + col, shp = col[p], shp[p] + other_scene = ( + col.view(self.size, self.size), + shp.view(self.size, self.size), + ) + + false = self.all_properties(other_scene) + + # We sometime add properties from a totally different + # scene to have negative "there is a xxx xxx" + # properties + if torch.rand(1).item() < 0.2: + other_scene = self.generate_scene() + false += self.all_properties(other_scene) + + false = list(set(false) - set(true)) + if len(false) >= self.nb_questions: + break + + if a < 10: + break + + true = [true[k] for k in torch.randperm(len(true))[: self.nb_questions]] + false = [false[k] for k in torch.randperm(len(false))[: self.nb_questions]] + true = [" " + q + " true" for q in true] + false = [" " + q + " false" for q in false] + + union = true + false + questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]] + + result = " ".join( + [" " + x for x in self.grid_positions(start_scene)] + + transformations + + questions + ) + + return start_scene, scene, result + + def generate_samples(self, nb, progress_bar=None): + result = [] + + r = range(nb) + if progress_bar is not None: + r = progress_bar(r) + + for _ in r: + result.append(self.generate_scene_and_questions()[2]) + + return result + + +###################################################################### + +if __name__ == "__main__": + import time + + grid_factory = GridFactory() + + # start_time = time.perf_counter() + # samples = grid_factory.generate_samples(10000) + # end_time = time.perf_counter() + # print(f"{len(samples) / (end_time - start_time):.02f} samples per second") + + start_scene, scene, questions = grid_factory.generate_scene_and_questions() + print() + print("-- Original scene -----------------------------") + print() + grid_factory.print_scene(start_scene) + print() + print("-- Transformed scene --------------------------") + print() + grid_factory.print_scene(scene) + print() + print("-- Sequence -----------------------------------") + print() + print(questions) + +###################################################################### diff --git a/main.py b/main.py new file mode 100755 index 0000000..df46652 --- /dev/null +++ b/main.py @@ -0,0 +1,912 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. 
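+######################################################################
+# GridFactory.random_transformations in grid.py above implements
+# rotations as compositions of flips and a transpose. A quick check of
+# the convention on a 2x2 example (for illustration only):
+#
+#   import torch
+#
+#   a = torch.arange(4).reshape(2, 2)  # [[0, 1], [2, 3]]
+#   a.flip(0).t()      # rotate  90 degrees: [[2, 0], [3, 1]]
+#   a.flip(0).flip(1)  # rotate 180 degrees: [[3, 2], [1, 0]]
+#   a.flip(1).t()      # rotate 270 degrees: [[1, 3], [0, 2]]
+######################################################################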
+# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import math, sys, argparse, time, tqdm, os, datetime, warnings + +import torch, torchvision +from torch import nn +from torch.nn import functional as F + +import ffutils +import mygpt, tasks, problems + +###################################################################### + +if torch.cuda.is_available(): + device = torch.device("cuda") + torch.backends.cuda.matmul.allow_tf32 = True +else: + device = torch.device("cpu") + +###################################################################### + +parser = argparse.ArgumentParser( + description="An implementation of GPT with cache.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, +) + +parser.add_argument( + "--task", + type=str, + default="twotargets", + help="byheart, learnop, guessop, mixing, memory, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp", +) + +parser.add_argument("--log_filename", type=str, default="train.log", help=" ") + +parser.add_argument("--result_dir", type=str, default=None) + +parser.add_argument("--seed", type=int, default=0) + +parser.add_argument("--max_percents_of_test_in_train", type=int, default=1) + +######################################## + +parser.add_argument("--nb_epochs", type=int, default=50) + +parser.add_argument("--batch_size", type=int, default=None) + +parser.add_argument("--nb_train_samples", type=int, default=None) + +parser.add_argument("--nb_test_samples", type=int, default=None) + +parser.add_argument("--optim", type=str, default="adam") + +######################################## + +parser.add_argument("--nb_warmup_iter", type=int, default=100) + +parser.add_argument("--nb_decay_iter", type=int, default=5000) + +parser.add_argument("--learning_rate", type=float, default=6e-4) + +parser.add_argument("--min_learning_rate", type=float, default=6e-5) + +######################################## + +parser.add_argument("--model", type=str, default=None) + +parser.add_argument("--attention", type=str, default=None) + +parser.add_argument("--dim_model", type=int, default=None) + +parser.add_argument("--dim_keys", type=int, default=None) + +parser.add_argument("--dim_hidden", type=int, default=None) + +parser.add_argument("--nb_heads", type=int, default=None) + +parser.add_argument("--nb_lines", type=int, default=None) + +parser.add_argument("--caterpillar_height", type=int, default=None) + +parser.add_argument("--rho", type=float, default=0.0) + +parser.add_argument("--dim_rec_v", type=int, default=None) + +parser.add_argument("--nb_blocks", type=int, default=None) + +parser.add_argument("--dropout", type=float, default=0.1) + +######################################## + +parser.add_argument("--deterministic_synthesis", action="store_true", default=False) + +parser.add_argument("--no_checkpoint", action="store_true", default=False) + +parser.add_argument("--overwrite_results", action="store_true", default=False) + +parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth") + +############################## +# rpl options + +parser.add_argument("--rpl_nb_starting_values", type=int, default=3) + +parser.add_argument("--rpl_max_input", type=int, default=9) + +parser.add_argument("--rpl_prog_len", type=int, default=8) + +parser.add_argument("--rpl_nb_runs", type=int, default=5) + +parser.add_argument("--rpl_no_prog", action="store_true", default=False) + +############################## +# grid options + +parser.add_argument("--grid_size", type=int, 
default=6) + +############################## +# picoclvr options + +parser.add_argument("--picoclvr_nb_colors", type=int, default=5) + +parser.add_argument("--picoclvr_height", type=int, default=12) + +parser.add_argument("--picoclvr_width", type=int, default=16) + +parser.add_argument("--picocvlr_prune_properties", type=str, default="none") + +############################## +# Maze options + +parser.add_argument("--maze_height", type=int, default=13) + +parser.add_argument("--maze_width", type=int, default=21) + +parser.add_argument("--maze_nb_walls", type=int, default=15) + +############################## +# Snake options + +parser.add_argument("--snake_height", type=int, default=9) + +parser.add_argument("--snake_width", type=int, default=12) + +parser.add_argument("--snake_nb_colors", type=int, default=5) + +parser.add_argument("--snake_length", type=int, default=200) + +############################## +# Stack options + +parser.add_argument("--stack_nb_steps", type=int, default=100) + +parser.add_argument("--stack_nb_stacks", type=int, default=3) + +parser.add_argument("--stack_nb_digits", type=int, default=3) + +parser.add_argument("--stack_fraction_values_for_train", type=float, default=0.75) + +############################## +# Expr options + +parser.add_argument("--expr_nb_variables", type=int, default=5) + +parser.add_argument("--expr_sequence_length", type=int, default=40) + +parser.add_argument("--expr_operand_max", type=int, default=9) + +parser.add_argument("--expr_result_max", type=int, default=99) + +parser.add_argument("--expr_input_file", type=str, default=None) + +############################## +# Memory + +parser.add_argument("--memory_len_total", type=int, default=32) + +############################## +# Mixing + +parser.add_argument("--mixing_hard", action="store_true", default=False) + +parser.add_argument("--mixing_deterministic_start", action="store_true", default=False) + +###################################################################### + +args = parser.parse_args() + +assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"} + +if args.result_dir is None: + args.result_dir = f"results_{args.task}_{args.model}" + +###################################################################### + +default_task_args = { + "addition": { + "model": "352M", + "batch_size": 25, + "nb_train_samples": 250000, + "nb_test_samples": 10000, + }, + "byheart": { + "model": "37M", + "batch_size": 25, + "nb_train_samples": 50000, + "nb_test_samples": 10000, + }, + "expr": { + "model": "352M", + "batch_size": 25, + "nb_train_samples": 2500000, + "nb_test_samples": 10000, + }, + "grid": { + "model": "37M", + "batch_size": 25, + "nb_train_samples": 250000, + "nb_test_samples": 10000, + }, + "qmlp": { + "model": "37M", + "batch_size": 10, + "nb_train_samples": 100000, + "nb_test_samples": 1000, + }, + "guessop": { + "model": "352M", + "batch_size": 25, + "nb_train_samples": 1000000, + "nb_test_samples": 10000, + }, + "learnop": { + "model": "37M", + "batch_size": 25, + "nb_train_samples": 50000, + "nb_test_samples": 10000, + }, + "maze": { + "model": "37M", + "batch_size": 5, + "nb_train_samples": 100000, + "nb_test_samples": 10000, + }, + "picoclvr": { + "model": "37M", + "batch_size": 25, + "nb_train_samples": 250000, + "nb_test_samples": 10000, + }, + "rpl": { + "model": "352M", + "batch_size": 5, + "nb_train_samples": 2500000, + "nb_test_samples": 10000, + }, + "snake": { + "model": "37M", + "batch_size": 25, + "nb_train_samples": 250000, + "nb_test_samples": 10000, + }, 
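+    # each "model" value selects an entry of default_model_args below,
+    # named by approximate parameter count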
+ "stack": { + "model": "37M", + "batch_size": 25, + "nb_train_samples": 100000, + "nb_test_samples": 1000, + }, + "twotargets": { + "model": "37M", + "batch_size": 25, + "nb_train_samples": 50000, + "nb_test_samples": 10000, + }, + "memory": { + "model": "37M", + "batch_size": 25, + "nb_train_samples": 25000, + "nb_test_samples": 10000, + }, + "mixing": { + "model": "37M", + "batch_size": 25, + "nb_train_samples": 250000, + "nb_test_samples": 10000, + }, + "mnist": { + "model": "37M", + "batch_size": 10, + "nb_train_samples": 60000, + "nb_test_samples": 10000, + }, +} + +if args.task in default_task_args: + for k, v in default_task_args[args.task].items(): + if getattr(args, k) is None: + setattr(args, k, v) + +###################################################################### + +default_model_args = { + "17K": { + "attention": "mha", + "dim_model": 32, + "dim_keys": 32, + "dim_hidden": 32, + "nb_heads": 2, + "dim_rec_v": 16, + "nb_blocks": 2, + }, + "17K-C": { + "attention": "caterpillar", + "dim_model": 32, + "dim_keys": 32, + "dim_hidden": 32, + "nb_heads": 2, + "nb_lines": 16, + "caterpillar_height": 4, + "dim_rec_v": 16, + "nb_blocks": 2, + }, + "4M": { + "attention": "mha", + "dim_model": 256, + "dim_keys": 32, + "dim_hidden": 1024, + "nb_heads": 4, + "dim_rec_v": 64, + "nb_blocks": 6, + }, + "4M-C": { + "attention": "caterpillar", + "dim_model": 256, + "dim_keys": 32, + "dim_hidden": 1024, + "nb_heads": 4, + "nb_lines": 32, + "caterpillar_height": 4, + "dim_rec_v": 64, # dim_model / nb_heads + "nb_blocks": 6, + }, + "37M": { + "dim_model": 512, + "dim_keys": 64, + "dim_hidden": 2048, + "nb_heads": 8, + "dim_rec_v": 64, + "nb_blocks": 12, + }, + "37M-C": { + "attention": "caterpillar", + "dim_model": 512, + "dim_keys": 64, + "dim_hidden": 2048, + "nb_heads": 8, + "nb_lines": 256, + "caterpillar_height": 32, + "dim_rec_v": 64, + "nb_blocks": 12, + }, + "122M": { + "attention": "mha", + "dim_model": 768, + "dim_keys": 64, + "dim_hidden": 2048, + "nb_heads": 8, + "dim_rec_v": 96, + "nb_blocks": 24, + }, + "122M-C": { + "attention": "caterpillar", + "dim_model": 768, + "dim_keys": 64, + "dim_hidden": 2048, + "nb_heads": 8, + "nb_lines": 128, + "dim_rec_v": 96, + "nb_blocks": 24, + }, + "352M": { + "attention": "mha", + "dim_model": 1024, + "dim_keys": 64, + "dim_hidden": 2048, + "nb_heads": 8, + "dim_rec_v": 128, + "nb_blocks": 48, + }, + "352M-C": { + "attention": "caterpillar", + "dim_model": 1024, + "dim_keys": 64, + "dim_hidden": 2048, + "nb_heads": 8, + "nb_lines": 128, + "dim_rec_v": 128, + "nb_blocks": 48, + }, +} + +if args.model in default_model_args: + for k, v in default_model_args[args.model].items(): + if getattr(args, k) is None: + setattr(args, k, v) +else: + raise ValueError(f"Unknown model {args.model}") + +###################################################################### + +try: + os.mkdir(args.result_dir) +except FileExistsError: + if not args.overwrite_results: + print(f"result directory {args.result_dir} already exists") + exit(1) + +log_file = open(os.path.join(args.result_dir, args.log_filename), "a") + +if args.seed >= 0: + # torch.backends.cudnn.deterministic = True + # torch.backends.cudnn.benchmark = False + # torch.use_deterministic_algorithms(True) + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + +###################################################################### + + +def log_string(s): + t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime()) + + if log_file is not None: + log_file.write(t 
+ s + "\n") + log_file.flush() + + print(t + s) + sys.stdout.flush() + + +with os.popen("sha256sum *.py") as f: + for l in f: + log_string(f"sha256sum {l.strip()}") + +now = time.strftime("%Y%m%d-%H%M%S", time.localtime()) +os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py *.sh") + +log_string(f"argv {' '.join(sys.argv)}") + +for n in vars(args): + log_string(f"args.{n} {getattr(args, n)}") + + +###################################################################### + +# from nanoGPT + + +def get_lr(it): + # 1) linear warmup for warmup_iter steps + if it < args.nb_warmup_iter: + return args.learning_rate * it / args.nb_warmup_iter + # 2) if it > nb_decay_iter, return min learning rate + if it > args.nb_decay_iter: + return args.min_learning_rate + # 3) in between, use cosine decay down to min learning rate + decay_ratio = (it - args.nb_warmup_iter) / ( + args.nb_decay_iter - args.nb_warmup_iter + ) + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + return args.min_learning_rate + coeff * ( + args.learning_rate - args.min_learning_rate + ) + + +###################################################################### + + +def picoclvr_pruner_horizontal_green(p): + return not ("green" in p and ("left" in p or "right" in p)) + + +picoclvr_pruner_train = ( + picoclvr_pruner_horizontal_green + if args.picocvlr_prune_properties in {"train+eval"} + else None +) + +picoclvr_pruner_eval = ( + (lambda p: not picoclvr_pruner_horizontal_green(p)) + if args.picocvlr_prune_properties in {"train+eval", "eval"} + else None +) + +###################################################################### + +device_data = device + +if args.task == "byheart": + task = tasks.SandBox( + problem=problems.ProblemByHeart(), + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + logger=log_string, + device=device_data, + ) + args.max_percents_of_test_in_train = -1 + +elif args.task == "learnop": + task = tasks.SandBox( + problem=problems.ProblemLearnOperator(), + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + logger=log_string, + device=device_data, + ) + + +elif args.task == "guessop": + task = tasks.SandBox( + problem=problems.ProblemGuessOperator(), + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + logger=log_string, + device=device_data, + ) + + +elif args.task == "twotargets": + task = tasks.SandBox( + problem=problems.ProblemTwoTargets(), + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + logger=log_string, + device=device_data, + ) + +elif args.task == "memory": + task = tasks.SandBox( + problem=problems.ProblemMemory(len_total=args.memory_len_total), + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + logger=log_string, + device=device_data, + ) + +elif args.task == "mixing": + task = tasks.SandBox( + problem=problems.ProblemMixing( + hard=args.mixing_hard, random_start=not args.mixing_deterministic_start + ), + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + logger=log_string, + device=device_data, + ) + +elif args.task == "addition": + task = tasks.SandBox( + problem=problems.ProblemAddition(), + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + 
logger=log_string, + device=device_data, + ) + +elif args.task == "picoclvr": + task = tasks.PicoCLVR( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + height=args.picoclvr_height, + width=args.picoclvr_width, + nb_colors=args.picoclvr_nb_colors, + logger=log_string, + device=device_data, + pruner_train=picoclvr_pruner_train, + pruner_eval=picoclvr_pruner_eval, + ) + +elif args.task == "mnist": + task = tasks.MNIST( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + device=device_data, + ) + +elif args.task == "maze": + task = tasks.Maze( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + height=args.maze_height, + width=args.maze_width, + nb_walls=args.maze_nb_walls, + device=device_data, + ) + +elif args.task == "snake": + task = tasks.Snake( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + height=args.snake_height, + width=args.snake_width, + nb_colors=args.snake_nb_colors, + length=args.snake_length, + prompt_length=args.snake_length // 2, + device=device_data, + ) + +elif args.task == "stack": + task = tasks.Stack( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + logger=log_string, + nb_steps=args.stack_nb_steps, + nb_stacks=args.stack_nb_stacks, + nb_digits=args.stack_nb_digits, + fraction_values_for_train=args.stack_fraction_values_for_train, + device=device_data, + ) + +elif args.task == "expr": + task = tasks.Expr( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + nb_variables=args.expr_nb_variables, + sequence_length=args.expr_sequence_length, + operand_max=args.expr_operand_max, + result_max=args.expr_result_max, + batch_size=args.batch_size, + device=device_data, + ) + +elif args.task == "rpl": + task = tasks.RPL( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + nb_starting_values=args.rpl_nb_starting_values, + max_input=args.rpl_max_input, + prog_len=args.rpl_prog_len, + nb_runs=args.rpl_nb_runs, + no_prog=args.rpl_no_prog, + logger=log_string, + device=device_data, + ) + +elif args.task == "grid": + task = tasks.Grid( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + size=args.grid_size, + logger=log_string, + device=device_data, + ) + +elif args.task == "qmlp": + task = tasks.QMLP( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + result_dir=args.result_dir, + logger=log_string, + device=device_data, + ) + +else: + raise ValueError(f"Unknown task {args.task}") + +###################################################################### + +log_string(f"device {device}") + +vocabulary_size = task.vocabulary_size() + +log_string(f"vocabulary_size {vocabulary_size}") + +############################## + +model = mygpt.MyGPT( + vocabulary_size=vocabulary_size, + dim_model=args.dim_model, + dim_keys=args.dim_keys, + dim_hidden=args.dim_hidden, + nb_heads=args.nb_heads, + nb_lines=args.nb_lines, + caterpillar_height=args.caterpillar_height, + dim_rec_v=args.dim_rec_v, + nb_blocks=args.nb_blocks, + causal=True, + dropout=args.dropout, + attention_layer=args.attention, +) + +model.to(device) + +nb_parameters = sum(p.numel() for p in 
model.parameters()) +log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)") + +###################################################################### + +nb_epochs_finished = 0 + +if args.no_checkpoint: + log_string(f"not trying to load checkpoint.") + +else: + try: + checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name) + checkpoint = torch.load(checkpoint_name) + nb_epochs_finished = checkpoint["nb_epochs_finished"] + model.load_state_dict(checkpoint["model_state"]) + torch.set_rng_state(checkpoint["rng_state"]) + if torch.cuda.is_available(): + torch.cuda.set_rng_state(checkpoint["cuda_rng_state"]) + + log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.") + + except FileNotFoundError: + log_string("starting from scratch.") + + except: + log_string("error when loading the checkpoint.") + exit(1) + +###################################################################### + +if args.task == "expr" and args.expr_input_file is not None: + task.produce_results( + n_epoch=nb_epochs_finished, + model=model, + result_dir=args.result_dir, + logger=log_string, + deterministic_synthesis=args.deterministic_synthesis, + input_file=args.expr_input_file, + ) + + exit(0) + +###################################################################### + +nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default + +# Compute the entropy of the training tokens + +token_count = 0 +for input in task.batches(split="train"): + token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1)) +token_probas = token_count / token_count.sum() +entropy = -torch.xlogy(token_probas, token_probas).sum() +train_set_perplexity = math.exp(entropy) + +###################################################################### +# A bit of paranoia never hurts + +if args.max_percents_of_test_in_train >= 0: + + def subsets_as_tuples(batches, cs): + s = set() + for batch in batches: + for x in batch: + s.add(tuple([v.item() for v in x])) + if len(s) == cs: + yield s + s = set() + yield s + + nb_test, nb_in_train = 0, 0 + for test_subset in subsets_as_tuples(task.batches(split="test"), 25000): + in_train = set() + for train_subset in subsets_as_tuples(task.batches(split="train"), 25000): + in_train.update(test_subset.intersection(train_subset)) + nb_in_train += len(in_train) + nb_test += len(test_subset) + + log_string( + f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set" + ) + + assert ( + nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100 + ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set" + +############################## + +nb_samples_seen = 0 + +if nb_epochs_finished >= nb_epochs: + task.produce_results( + n_epoch=nb_epochs_finished, + model=model, + result_dir=args.result_dir, + logger=log_string, + deterministic_synthesis=args.deterministic_synthesis, + ) + +time_pred_result = None + +it = 0 + +for n_epoch in range(nb_epochs_finished, nb_epochs): + if args.optim == "sgd": + optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate) + elif args.optim == "adam": + optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) + elif args.optim == "adamw": + optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate) + else: + raise ValueError(f"Unknown optimizer {args.optim}.") + + model.train() + + nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0 + + for input in 
task.batches(split="train"): + model.reset_inner_loss() + input = input.to(device) + + output = model(mygpt.BracketedSequence(input)).x + loss = F.cross_entropy(output.transpose(1, 2), input) + inner_loss = model.get_inner_loss() + + acc_train_loss += loss.item() * input.size(0) + acc_train_inner_loss += inner_loss.item() * input.size(0) + + nb_train_samples += input.size(0) + nb_samples_seen += input.size(0) + + total_loss = loss + (args.rho * inner_loss if args.rho > 0 else 0.0) + + it += 1 + lr = get_lr(it) + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + # log_string(f"learning_rate {lr}") + + optimizer.zero_grad() + total_loss.backward() + optimizer.step() + + with torch.autograd.no_grad(): + model.eval() + + nb_test_samples, acc_test_loss = 0, 0.0 + + for input in task.batches(split="test"): + input = input.to(device) + + output = model(mygpt.BracketedSequence(input)).x + loss = F.cross_entropy(output.transpose(1, 2), input) + acc_test_loss += loss.item() * input.size(0) + nb_test_samples += input.size(0) + + log_string( + f"loss {n_epoch} train_loss {acc_train_loss/nb_train_samples} train_inner_loss {acc_train_inner_loss/nb_train_samples} test_prediction {acc_test_loss/nb_test_samples}" + ) + + task.produce_results( + n_epoch=n_epoch, + model=model, + result_dir=args.result_dir, + logger=log_string, + deterministic_synthesis=args.deterministic_synthesis, + ) + + train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) + test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples)) + + log_string( + f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}" + ) + + time_current_result = datetime.datetime.now() + if time_pred_result is not None: + log_string( + f"next_result {time_current_result + (time_current_result - time_pred_result)}" + ) + time_pred_result = time_current_result + + checkpoint = { + "nb_epochs_finished": n_epoch + 1, + "model_state": model.state_dict(), + "rng_state": torch.get_rng_state(), + } + + if torch.cuda.is_available(): + checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state() + + checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name) + torch.save(checkpoint, checkpoint_name) + log_string(f"saved checkpoint {checkpoint_name}") + +###################################################################### diff --git a/maze.py b/maze.py new file mode 100755 index 0000000..8ac9fce --- /dev/null +++ b/maze.py @@ -0,0 +1,309 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. 
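+######################################################################
+# A quick numeric check of the nanoGPT-style cosine schedule computed
+# by get_lr in main.py above, with its default arguments
+# (learning_rate 6e-4, min_learning_rate 6e-5, nb_warmup_iter 100,
+# nb_decay_iter 5000):
+#
+#   get_lr(0)    -> 0.0      (linear warmup starts from zero)
+#   get_lr(50)   -> 3.0e-4   (halfway through warmup)
+#   get_lr(100)  -> 6.0e-4   (peak learning rate)
+#   get_lr(2550) -> 3.3e-4   (cosine midpoint, 6e-5 + 0.5 * 5.4e-4)
+#   get_lr(5000) -> 6.0e-5   (end of decay, constant afterward)
+######################################################################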
+# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import torch, torchvision + +###################################################################### + +v_empty, v_wall, v_start, v_goal, v_path = 0, 1, 2, 3, 4 + + +def create_maze(h=11, w=17, nb_walls=8): + assert h % 2 == 1 and w % 2 == 1 + + a, k = 0, 0 + + while k < nb_walls: + while True: + if a == 0: + m = torch.zeros(h, w, dtype=torch.int64) + m[0, :] = 1 + m[-1, :] = 1 + m[:, 0] = 1 + m[:, -1] = 1 + + r = torch.rand(4) + + if r[0] <= 0.5: + i1, i2, j = ( + int((r[1] * h).item()), + int((r[2] * h).item()), + int((r[3] * w).item()), + ) + i1, i2, j = i1 - i1 % 2, i2 - i2 % 2, j - j % 2 + i1, i2 = min(i1, i2), max(i1, i2) + if i2 - i1 > 1 and i2 - i1 <= h / 2 and m[i1 : i2 + 1, j].sum() <= 1: + m[i1 : i2 + 1, j] = 1 + break + else: + i, j1, j2 = ( + int((r[1] * h).item()), + int((r[2] * w).item()), + int((r[3] * w).item()), + ) + i, j1, j2 = i - i % 2, j1 - j1 % 2, j2 - j2 % 2 + j1, j2 = min(j1, j2), max(j1, j2) + if j2 - j1 > 1 and j2 - j1 <= w / 2 and m[i, j1 : j2 + 1].sum() <= 1: + m[i, j1 : j2 + 1] = 1 + break + a += 1 + + if a > 10 * nb_walls: + a, k = 0, 0 + + k += 1 + + return m + + +###################################################################### + + +def compute_distance(walls, goal_i, goal_j): + max_length = walls.numel() + dist = torch.full_like(walls, max_length) + + dist[goal_i, goal_j] = 0 + pred_dist = torch.empty_like(dist) + + while True: + pred_dist.copy_(dist) + d = ( + torch.cat( + ( + dist[None, 1:-1, 0:-2], + dist[None, 2:, 1:-1], + dist[None, 1:-1, 2:], + dist[None, 0:-2, 1:-1], + ), + 0, + ).min(dim=0)[0] + + 1 + ) + + dist[1:-1, 1:-1] = torch.min(dist[1:-1, 1:-1], d) + dist = walls * max_length + (1 - walls) * dist + + if dist.equal(pred_dist): + return dist * (1 - walls) + + +###################################################################### + + +def compute_policy(walls, goal_i, goal_j): + distance = compute_distance(walls, goal_i, goal_j) + distance = distance + walls.numel() * walls + + value = distance.new_full((4,) + distance.size(), walls.numel()) + value[0, :, 1:] = distance[:, :-1] # < + value[1, :, :-1] = distance[:, 1:] # > + value[2, 1:, :] = distance[:-1, :] # ^ + value[3, :-1, :] = distance[1:, :] # v + + proba = (value.min(dim=0)[0][None] == value).float() + proba = proba / proba.sum(dim=0)[None] + proba = proba * (1 - walls) + walls.float() / 4 + + return proba + + +def stationary_densities(mazes, policies): + policies = policies * (mazes != v_goal)[:, None] + start = (mazes == v_start).nonzero(as_tuple=True) + probas = mazes.new_zeros(mazes.size(), dtype=torch.float32) + pred_probas = probas.clone() + probas[start] = 1.0 + + while not pred_probas.equal(probas): + pred_probas.copy_(probas) + probas.zero_() + probas[:, 1:, :] += pred_probas[:, :-1, :] * policies[:, 3, :-1, :] + probas[:, :-1, :] += pred_probas[:, 1:, :] * policies[:, 2, 1:, :] + probas[:, :, 1:] += pred_probas[:, :, :-1] * policies[:, 1, :, :-1] + probas[:, :, :-1] += pred_probas[:, :, 1:] * policies[:, 0, :, 1:] + probas[start] = 1.0 + + return probas + + +###################################################################### + + +def mark_path(walls, i, j, goal_i, goal_j, policy): + action = torch.distributions.categorical.Categorical( + policy.permute(1, 2, 0) + ).sample() + n, nmax = 0, walls.numel() + while i != goal_i or j != goal_j: + di, dj = [(0, -1), (0, 1), (-1, 0), (1, 0)][action[i, j]] + i, j = i + di, j + dj + assert walls[i, j] == 0 + walls[i, j] = v_path + n += 1 + 
assert n < nmax + + +def path_optimality(ref_paths, paths): + return (ref_paths == v_path).long().flatten(1).sum(1) == ( + paths == v_path + ).long().flatten(1).sum(1) + + +def path_correctness(mazes, paths): + still_ok = (mazes - (paths * (paths != v_path))).view(mazes.size(0), -1).abs().sum( + 1 + ) == 0 + reached = still_ok.new_zeros(still_ok.size()) + current, pred_current = paths.clone(), paths.new_zeros(paths.size()) + goal = (mazes == v_goal).long() + while not pred_current.equal(current): + pred_current.copy_(current) + u = (current == v_start).long() + possible_next = ( + u[:, 2:, 1:-1] + u[:, 0:-2, 1:-1] + u[:, 1:-1, 2:] + u[:, 1:-1, 0:-2] > 0 + ).long() + u = u[:, 1:-1, 1:-1] + reached += ((goal[:, 1:-1, 1:-1] * possible_next).sum((1, 2)) == 1) * ( + (current == v_path).sum((1, 2)) == 0 + ) + current[:, 1:-1, 1:-1] = (1 - u) * current[:, 1:-1, 1:-1] + ( + v_start - v_path + ) * (possible_next * (current[:, 1:-1, 1:-1] == v_path)) + still_ok *= (current == v_start).sum((1, 2)) <= 1 + + return still_ok * reached + + +###################################################################### + + +def create_maze_data( + nb, height=11, width=17, nb_walls=8, dist_min=10, progress_bar=lambda x: x +): + mazes = torch.empty(nb, height, width, dtype=torch.int64) + paths = torch.empty(nb, height, width, dtype=torch.int64) + policies = torch.empty(nb, 4, height, width) + + for n in progress_bar(range(nb)): + maze = create_maze(height, width, nb_walls) + i = (maze == v_empty).nonzero() + while True: + start, goal = i[torch.randperm(i.size(0))[:2]] + if (start - goal).abs().sum() >= dist_min: + break + start_i, start_j, goal_i, goal_j = start[0], start[1], goal[0], goal[1] + + policy = compute_policy(maze, goal_i, goal_j) + path = maze.clone() + mark_path(path, start_i, start_j, goal_i, goal_j, policy) + maze[start_i, start_j] = v_start + maze[goal_i, goal_j] = v_goal + path[start_i, start_j] = v_start + path[goal_i, goal_j] = v_goal + + mazes[n] = maze + paths[n] = path + policies[n] = policy + + return mazes, paths, policies + + +###################################################################### + + +def save_image( + name, + mazes, + target_paths=None, + predicted_paths=None, + path_correct=None, + path_optimal=None, +): + colors = torch.tensor( + [ + [255, 255, 255], # empty + [0, 0, 0], # wall + [0, 255, 0], # start + [127, 127, 255], # goal + [255, 0, 0], # path + ] + ) + + mazes = mazes.cpu() + + c_mazes = ( + colors[mazes.reshape(-1)].reshape(mazes.size() + (-1,)).permute(0, 3, 1, 2) + ) + + imgs = c_mazes.unsqueeze(1) + + if target_paths is not None: + target_paths = target_paths.cpu() + + c_target_paths = ( + colors[target_paths.reshape(-1)] + .reshape(target_paths.size() + (-1,)) + .permute(0, 3, 1, 2) + ) + + imgs = torch.cat((imgs, c_target_paths.unsqueeze(1)), 1) + + if predicted_paths is not None: + predicted_paths = predicted_paths.cpu() + c_predicted_paths = ( + colors[predicted_paths.reshape(-1)] + .reshape(predicted_paths.size() + (-1,)) + .permute(0, 3, 1, 2) + ) + imgs = torch.cat((imgs, c_predicted_paths.unsqueeze(1)), 1) + + img = torch.tensor([255, 255, 0]).view(1, -1, 1, 1) + + # NxKxCxHxW + if path_optimal is not None: + path_optimal = path_optimal.cpu().long().view(-1, 1, 1, 1) + img = ( + img * (1 - path_optimal) + + torch.tensor([0, 255, 0]).view(1, -1, 1, 1) * path_optimal + ) + + if path_correct is not None: + path_correct = path_correct.cpu().long().view(-1, 1, 1, 1) + img = img * path_correct + torch.tensor([255, 0, 0]).view(1, -1, 1, 1) * ( + 1 - 
path_correct + ) + + img = img.expand( + -1, -1, imgs.size(3) + 2, 1 + imgs.size(1) * (1 + imgs.size(4)) + ).clone() + + print(f"{img.size()=} {imgs.size()=}") + + for k in range(imgs.size(1)): + img[ + :, + :, + 1 : 1 + imgs.size(3), + 1 + k * (1 + imgs.size(4)) : 1 + k * (1 + imgs.size(4)) + imgs.size(4), + ] = imgs[:, k] + + img = img.float() / 255.0 + + torchvision.utils.save_image(img, name, nrow=4, padding=1, pad_value=224.0 / 256) + + +###################################################################### + +if __name__ == "__main__": + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + mazes, paths, policies = create_maze_data(8) + mazes, paths = mazes.to(device), paths.to(device) + save_image("test.png", mazes=mazes, target_paths=paths, predicted_paths=paths) + print(path_correctness(mazes, paths)) + +###################################################################### diff --git a/memload.py b/memload.py new file mode 100755 index 0000000..5fcd089 --- /dev/null +++ b/memload.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python + +import torch + +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CppExtension + +cpp_source = """ +std::vector greedy_lines_allocation(torch::Tensor load_start, float decay, torch::Tensor line_requests) { + auto nb_lines = load_start.size(1); + auto batch_size = line_requests.size(0); + auto nb_heads = line_requests.size(1); + auto T = line_requests.size(2); + + auto load_start_a = load_start.accessor(); + auto line_requests_a = line_requests.accessor(); + + auto load = torch::empty({batch_size, nb_lines, T}); + auto load_a = load.accessor(); + + auto allocation_result = torch::empty({batch_size,nb_heads,T},torch::TensorOptions().dtype(torch::kInt64)); + auto allocation_result_a = allocation_result.accessor(); + + for(int n = 0; n < batch_size; n++) { + for(int t = 0; t < T; t++) { + for(int l = 0; l < nb_lines; l++) { + if(t == 0) { + load[n][l][t] = decay * load_start_a[n][l]; + } else { + load[n][l][t] = decay * load[n][l][t-1]; + } + } + for(int h = 0; h < nb_heads; h++) { + if(line_requests_a[n][h][t] > 0) { + int l_lowest_load; + for(int l = 0; l < nb_lines; l++) { + if(l == 0 || load_a[n][l][t] + +# This is an implementation from scratch of a "GPT", that is a model +# composed of several causal self-attention blocks. It is equipped +# with a caching mechanism for keys and values to avoid a O(N^3) cost +# for auto-regression. + +import math, warnings + +import torch, einops + +from torch import nn +from torch.nn import functional as F +from functorch.dim import dims + +import ffutils + +# import memload + +###################################################################### + +# A BracketedSequence is a BxTx... tensor with a first and a nb time +# steps to compute. + +# Modules able to process it expect that they will have to process a +# first bracket starting at t=0, followed by a succession of brackets +# that move forward in time, do not overlap, and cover the axis T with +# no holes. +# +# Although it is more general, for a classical prompt-conditioned +# auto-regressive process it will be a first bracket starting at 0 and +# of arbitrary length for the "prompt", followed by brackets of length +# 1 for the successive tokens. 
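+#
+# For instance, with a prompt of 7 tokens in a sequence x of total
+# length 10 (sizes hypothetical), the successive brackets would be
+#
+#   BracketedSequence(x, first=0, nb=7, init_cache=True)   # the prompt
+#   BracketedSequence(x, first=7, nb=1, init_cache=False)  # 1st new token
+#   BracketedSequence(x, first=8, nb=1, init_cache=False)  # 2nd new token
+#   BracketedSequence(x, first=9, nb=1, init_cache=False)  # 3rd new token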
+# +# Modules able to process brackets may implement a cache that is +# resetted when the input bracket starts at t=0 + + +class BracketedSequence: + def __init__(self, x, first=None, nb=None, init_cache=None): + self.x = x + assert (first is None and nb is None and init_cache is None) or ( + first is not None and nb is not None and init_cache is not None + ) + + self.first = 0 if first is None else first + self.nb = x.size(1) if nb is None else nb + self.init_cache = True if init_cache is None else init_cache + + def slice(self): + return self.x[:, self.first : self.first + self.nb] + + def complete(self): + return self.first == 0 and self.nb == self.x.size(1) + + +###################################################################### + + +class CacheWrapper(nn.Module): + def __init__(self, *f): + super().__init__() + self.f = f[0] if len(f) == 1 else nn.Sequential(*f) + + def forward(self, bs): + if bs.init_cache: + y = self.f(bs.slice()) + self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:])) + self.cache_y[:, bs.first : bs.first + bs.nb] = y + else: + assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2]) + assert bs.first + bs.nb <= self.cache_y.size(1) + self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice()) + + return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache) + + +############################## + + +class WithResidual(nn.Module): + def __init__(self, *f): + super().__init__() + self.f = f[0] if len(f) == 1 else nn.Sequential(*f) + + def forward(self, bs): + return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb, bs.init_cache) + + +############################## + + +class AddPositionalEncoding(nn.Module): + def __init__(self, len_max): + super().__init__() + self.len_max = len_max + + # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D})) + + def forward(self, bs): + if bs.init_cache: + t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[ + :, None + ] + j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[ + None, : + ] + k = j % 2 + self.pe = torch.sin( + t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k + ) + self.cache_y = bs.x.new(bs.x.size()) + + self.cache_y[:, bs.first : bs.first + bs.nb] = ( + bs.slice() + self.pe[bs.first : bs.first + bs.nb] + ) + + return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache) + + +import pscan + + +# X is /.../xTxD A is /.../xT Y_init is /.../xD + + +def pscan_dim(A, X, Y_init, dim=-2): + s = X.size() + a, T, b = s[:dim].numel(), s[dim], s[dim + 1 :].numel() + + A = A.reshape(a, T, *s[dim + 1 : -1]) + X = X.reshape(a, T, *s[dim + 1 : -1], -1) + + if Y_init is None: + Y_init = X.new_zeros(a, *s[dim + 1 : -1], X.size(-1)) + else: + Y_init = Y_init.reshape(a, *s[dim + 1 : -1], -1) + + Y = pscan.pscan(A, X, Y_init).reshape(s) + + return Y + + +def pscan_shape(A, X, Y_init): + s = X.size() + A = A.reshape(-1, s[-2]) + X = X.reshape(-1, s[-2], s[-1]) + + if Y_init is None: + Y_init = X.new_zeros(X.size(0), s[-1]) + else: + Y_init = Y_init.reshape(-1, s[-1]) + + Y = pscan.pscan(A, X, Y_init).reshape(s) + + return Y + + +def nsum_shape(X, Y_init): + s = X.size() + X = X.reshape(-1, s[-2], s[-1]) # ntd + + Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1]) + result = [] + + for k in range(X.size(1)): + Y = Y + X[:, k] + Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1) + result.append(Y) + + return torch.cat(result, dim=1).reshape(s) + + +############################## + + +class 
DumbRec(nn.Module): + def __init__( + self, + dim_in, + dim_qk, + dim_v, + nb_heads, + nb_lines, + attention_dropout=0.0, + len_max=1e5, + ): + super().__init__() + + def randw(*d): + return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) + + self.nb_lines = nb_lines + self.attention_dropout = attention_dropout + + self.k_star = randw(nb_lines, dim_qk) + + self.w_qw = randw(nb_heads, dim_qk, dim_in) + self.w_qr = randw(nb_heads, dim_qk, dim_in) + # self.w_k = randw(nb_heads, dim_qk, dim_in) + self.w_v = randw(nb_heads, dim_v, dim_in) + self.w_o = randw(dim_v * nb_heads, dim_in) + + def reset_inner_loss(self): + self.acc_attention = 0 + self.acc_nb = 0 + + def get_inner_loss(self): + warnings.warn("l2 regularization", RuntimeWarning) + return (self.acc_attention / self.acc_nb).pow(2).sum() + # return torch.tensor([0], device=self.w_qw.device) + + def forward(self, bs): + x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb + + if bs.init_cache: + self.rec_v = x_q.new_zeros( + x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1) + ) + # self.rec_k = x_q.new_zeros( + # x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1) + # ) + self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1)) + + ###################################################################### + # Prepare the keys + + k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1) + + warnings.warn("rotating key barrel", RuntimeWarning) + k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1) + t_barrel = torch.arange(t0, t1, device=k_star.device) + t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0) + l_barrel = ( + torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel + ) % k_star.size(0) + k_star = k_star[l_barrel, t_barrel] + + ###################################################################### + # Compute the recurrent state + + qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw) + + v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v) + # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k) + + aw = torch.einsum( + "nhtd,ltd->nhlt", + qw, + k_star, + ) / math.sqrt(self.w_qw.size(1)) + + aw = aw.softmax(dim=2) # nhlt + + if self.train: + self.acc_attention += aw.sum(dim=(0, 1, 3)) + self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3) + + aw = F.dropout(aw, self.attention_dropout, self.training) + + A = 1 - aw.sum(dim=1) # nlt + + V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous() + # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous() + + if t0 == 0: + V0 = None + # K0 = None + else: + V0 = self.rec_v[:, :, t0 - 1] + # K0 = self.rec_k[:, :, t0 - 1] + + self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0) + # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0) + + ###################################################################### + # compute the readout + + qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr) + + ar = torch.einsum( + "nhtd,ld->nhlt", + qr, + # self.rec_k[:, :, t0:t1], + self.k_star, + ) / math.sqrt(self.w_qr.size(1)) + + ar = ar.softmax(dim=2) # nhlt + + ar = F.dropout(ar, self.attention_dropout, self.training) + + y = torch.einsum( + "nhlt,nltd->nthd", + ar, + self.rec_v[:, :, t0:t1], + ).flatten(2) + + self.cache_y[:, t0:t1] = y @ self.w_o + + return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache) + + +############################## + + +class KVRec(nn.Module): + def __init__( + self, + dim_in, + dim_qk, + dim_v, + nb_heads, + nb_lines, + attention_dropout=0.0, + len_max=1e5, + ): + super().__init__() + + def 
+        def randw(*d):
+            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+
+        self.nb_lines = nb_lines
+        self.attention_dropout = attention_dropout
+
+        self.k_star = randw(nb_lines, dim_qk)
+
+        self.w_qw = randw(nb_heads, dim_qk, dim_in)
+        self.w_qr = randw(nb_heads, dim_qk, dim_in)
+        self.w_k = randw(nb_heads, dim_qk, dim_in)
+        self.w_v = randw(nb_heads, dim_v, dim_in)
+        self.w_o = randw(dim_v * nb_heads, dim_in)
+
+    def reset_inner_loss(self):
+        self.acc_attention = 0
+        self.acc_nb = 0
+
+    def get_inner_loss(self):
+        warnings.warn("l2 regularization", RuntimeWarning)
+        return (self.acc_attention / self.acc_nb).pow(2).sum()
+        # return torch.tensor([0], device=self.w_qw.device)
+        # warnings.warn("side regularization", RuntimeWarning)
+        # return (
+        #     (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
+        # )
+        # return torch.tensor([0], device=self.w_qw.device)
+
+    def forward(self, bs):
+        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
+
+        # n,h,l,t,d = dims(5)
+
+        if bs.init_cache:
+            self.rec_v = x_q.new_zeros(
+                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
+            )
+            self.rec_k = x_q.new_zeros(
+                x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
+            )
+            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
+
+        ######################################################################
+        # Prepare the keys
+
+        k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
+
+        warnings.warn("rotating key barrel", RuntimeWarning)
+        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
+        t_barrel = torch.arange(t0, t1, device=k_star.device)
+        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
+        l_barrel = (
+            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
+        ) % k_star.size(0)
+        k_star = k_star[l_barrel, t_barrel]
+
+        ######################################################################
+        # Compute the recurrent state
+
+        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
+
+        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
+        k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
+
+        aw = torch.einsum(
+            "nhtd,ltd->nhlt",
+            qw,
+            k_star,
+        ) / math.sqrt(self.w_qw.size(1))
+
+        aw = aw.softmax(dim=2)  # nhlt
+
+        if self.training:
+            # We want all the memory lines to be used similarly
+            self.acc_attention += aw.sum(dim=(0, 1, 3))  # Sum across NxHx_xT
+            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
+
+        aw = F.dropout(aw, self.attention_dropout, self.training)
+
+        A = 1 - aw.sum(dim=1)  # nlt
+
+        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
+        K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
+
+        if t0 == 0:
+            V0 = None
+            K0 = None
+        else:
+            V0 = self.rec_v[:, :, t0 - 1]
+            K0 = self.rec_k[:, :, t0 - 1]
+
+        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
+        self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
+
+        ######################################################################
+        # compute the readout
+
+        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)
+
+        ar = torch.einsum(
+            "nhtd,nltd->nhlt",
+            qr,
+            self.rec_k[:, :, t0:t1],
+        ) / math.sqrt(self.w_qr.size(1))
+
+        ar = ar.softmax(dim=2)  # nhlt
+
+        ar = F.dropout(ar, self.attention_dropout, self.training)
+
+        y = torch.einsum(
+            "nhlt,nltd->nthd",
+            ar,
+            self.rec_v[:, :, t0:t1],
+        ).flatten(2)
+
+        self.cache_y[:, t0:t1] = y @ self.w_o
+
+        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
+
+
+##############################
+
+
+def moving_window(x, dim, win_dim, win_size):
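+    # Returns a view of x (no copy, thanks to as_strided) in which
+    # dimension win_dim enumerates win_size consecutive values along
+    # dim. A sketch with made-up sizes: if x has size (2, 10, 3),
+    # then y = moving_window(x, dim=1, win_dim=2, win_size=4) has
+    # size (2, 7, 4, 3) with y[n, t, k, d] == x[n, t + k, d].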
+    size, stride = x.size(), x.stride()
+    size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
+    size = size[:win_dim] + (win_size,) + size[win_dim:]
+    stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]
+
+    return x.as_strided(size=size, stride=stride)
+
+
+##############################
+
+
+class Caterpillar(nn.Module):
+    def __init__(
+        self,
+        dim_in,
+        dim_qk,
+        dim_v,
+        nb_heads,
+        caterpillar_length,
+        caterpillar_height,
+        attention_dropout=0.0,
+        len_max=1e5,
+    ):
+        super().__init__()
+
+        warnings.warn("Caterpillar", RuntimeWarning)
+
+        def randw(*d):
+            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+
+        self.caterpillar_length = caterpillar_length
+        self.caterpillar_height = caterpillar_height
+        self.attention_dropout = attention_dropout
+
+        self.w_G = randw(nb_heads, caterpillar_height, dim_in)
+        self.b_G = nn.Parameter(
+            torch.full(
+                (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
+            )
+        )
+
+        self.w_K = randw(nb_heads, dim_qk, dim_in)
+        self.w_V = randw(nb_heads, dim_v, dim_in)
+        self.w_Q = randw(nb_heads, dim_qk, dim_in)
+        self.w_O = randw(dim_v * nb_heads, dim_in)
+
+        self.init_K_rec = randw(caterpillar_height, caterpillar_length, dim_qk)
+        self.init_V_rec = randw(caterpillar_height, caterpillar_length, dim_v)
+
+    def reset_inner_loss(self):
+        self.acc_attention = 0
+        self.acc_nb = 0
+
+    def get_inner_loss(self):
+        # warnings.warn("l2 regularization", RuntimeWarning)
+        # return (self.acc_attention / self.acc_nb).pow(2).sum()
+        return torch.tensor([0], device=self.w_Q.device)
+
+    def forward(self, bs):
+        # Naming the dimensions makes the code below clearer
+
+        X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb
+
+        N = bs.x.size(0)
+        T = bs.x.size(1)
+        DV = self.w_V.size(1)
+        DK = self.w_K.size(1)
+        Dout = self.w_O.size(1)
+        CH = self.caterpillar_height
+        CL = self.caterpillar_length
+
+        assert (
+            t0 >= CL and (t1 - t0) % CL == 0
+        ), "bs.first should be at least caterpillar_length, and bs.nb should be a multiple of caterpillar_length"
+
+        if bs.init_cache:
+            self.rec_V = X.new_zeros(N, CH, T, DV)
+            self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :]
+            self.rec_K = X.new_zeros(N, CH, T, DK)
+            self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :]
+            self.cache_Y = X.new_zeros(N, T, Dout)
+
+        ######################################################################
+        # Compute the recurrent state
+
+        G = (
+            torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
+        ).sigmoid()
+
+        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
+        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
+
+        A = 1 - G.sum(1)
+        gated_V = torch.einsum("nhet,nhtd->netd", G, V)
+        gated_K = torch.einsum("nhet,nhtd->netd", G, K)
+
+        init_rec_V = self.rec_V[:, :, t0 - CL : t0]
+        init_rec_K = self.rec_K[:, :, t0 - CL : t0]
+
+        A = A.unflatten(2, (-1, CL))
+        gated_V = gated_V.unflatten(2, (-1, CL))
+        gated_K = gated_K.unflatten(2, (-1, CL))
+
+        next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
+        next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)
+
+        self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3)
+        self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)
+
+        ######################################################################
+        # compute the readout
+
+        Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)
+
+        uv = moving_window(
+            self.rec_V[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
+        )
+
+        uk = moving_window(
+            self.rec_K[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
+        )
+
+        ar = torch.einsum(
+            "nhtd,nftld->nhtfl",
+            Q,
+            uk,
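+            # readout attention: each query attends, over its
+            # CL-step window, to the CH x CL recurrent key slots
+            # (f indexes lines, l window positions)
+        )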
/ math.sqrt(DK) + + ar = ar.flatten(3).softmax(dim=3).view(ar.size()) + + ar = F.dropout(ar, self.attention_dropout, self.training) + + Y = torch.einsum( + "nhtfl,nftld->nthd", + ar, + uv, + ).flatten(2) + + self.cache_Y[:, t0:t1] = Y @ self.w_O + + return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache) + + +############################## + + +class QKVAttention(nn.Module): + def __init__( + self, + dim_in, + dim_qk, + dim_v, + nb_heads=1, + causal=False, + attention_dropout=0.0, + ): + super().__init__() + + def randw(*d): + return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) + + self.causal = causal + self.attention_dropout = attention_dropout + self.record_attention = False + + self.w_q = randw(nb_heads, dim_qk, dim_in) + self.w_k = randw(nb_heads, dim_qk, dim_in) + self.w_v = randw(nb_heads, dim_v, dim_in) + self.w_o = randw(dim_v * nb_heads, dim_in) + + def forward(self, bs): + x_q = bs.x + + assert ( + self.causal or bs.complete() + ), "Partial evaluation is only possible for causal models" + + if bs.init_cache: + self.cache_k = x_q.new_zeros( + x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1) + ) + self.cache_v = x_q.new_zeros( + x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1) + ) + self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1)) + + q = torch.einsum("ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_q) + + self.cache_k[:, :, bs.first : bs.first + bs.nb] = torch.einsum( + "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_k + ) + self.cache_v[:, :, bs.first : bs.first + bs.nb] = torch.einsum( + "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_v + ) + + a = torch.einsum( + "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs.first + bs.nb] + ) / math.sqrt(self.w_q.size(1)) + + if self.causal: + if bs.init_cache: + self.cache_attzero = ( + torch.arange(x_q.size(1), device=q.device)[None, None, :, None] + < torch.arange(x_q.size(1), device=q.device)[None, None, None, :] + ) + a = a.masked_fill( + self.cache_attzero[ + :, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb + ], + float("-inf"), + ) + + a = a.softmax(dim=3) + + if self.record_attention: + self.a = a + + a = F.dropout(a, self.attention_dropout, self.training) + + y = torch.einsum( + "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs.first + bs.nb] + ).flatten(2) + + self.cache_y[:, bs.first : bs.first + bs.nb] = y @ self.w_o + + return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache) + + +############################## + + +class MyGPT(nn.Module): + def __init__( + self, + vocabulary_size, + dim_model, + dim_keys, + dim_hidden, + nb_heads, + nb_blocks, + nb_lines=None, + caterpillar_height=None, + dim_rec_v=-1, + causal=False, + dropout=0.0, + len_max=1e5, + attention_layer="kvrec", + ): + super().__init__() + + assert attention_layer in {"mha", "dumbrec", "kvrec", "caterpillar"} + + if attention_layer == "caterpillar": + assert nb_lines % caterpillar_height == 0 + self.caterpillar_length = nb_lines // caterpillar_height + self.caterpillar_height = caterpillar_height + else: + self.caterpillar_length = -1 + self.caterpillar_height = -1 + + assert dim_model % nb_heads == 0 + + self.embedding = nn.Sequential( + CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)), + AddPositionalEncoding(len_max), + ) + + trunk_blocks = [] + + def attlayer(): + if attention_layer == "mha": + return QKVAttention( + dim_in=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + causal=causal, + 
attention_dropout=dropout,
+                )
+            elif attention_layer == "dumbrec":
+                return DumbRec(
+                    dim_in=dim_model,
+                    dim_qk=dim_keys,
+                    dim_v=dim_rec_v,
+                    nb_heads=nb_heads,
+                    nb_lines=nb_lines,
+                    attention_dropout=dropout,
+                )
+            elif attention_layer == "kvrec":
+                return KVRec(
+                    dim_in=dim_model,
+                    dim_qk=dim_keys,
+                    dim_v=dim_rec_v,
+                    nb_heads=nb_heads,
+                    nb_lines=nb_lines,
+                    attention_dropout=dropout,
+                )
+            elif attention_layer == "caterpillar":
+                return Caterpillar(
+                    dim_in=dim_model,
+                    dim_qk=dim_keys,
+                    dim_v=dim_rec_v,
+                    nb_heads=nb_heads,
+                    caterpillar_length=self.caterpillar_length,
+                    caterpillar_height=self.caterpillar_height,
+                    attention_dropout=dropout,
+                )
+            else:
+                raise ValueError(f"Unknown attention type {attention_layer}.")
+
+        for b in range(nb_blocks):
+            trunk_blocks += [
+                WithResidual(
+                    CacheWrapper(nn.LayerNorm((dim_model,))),
+                    attlayer(),
+                ),
+                WithResidual(
+                    CacheWrapper(
+                        nn.LayerNorm((dim_model,)),
+                        nn.Linear(in_features=dim_model, out_features=dim_hidden),
+                        nn.ReLU(),
+                        nn.Linear(in_features=dim_hidden, out_features=dim_model),
+                        nn.Dropout(dropout),
+                    ),
+                ),
+            ]
+
+        self.trunk = nn.Sequential(*trunk_blocks)
+
+        self.readout = CacheWrapper(
+            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
+        )
+
+        with torch.no_grad():
+            for m in self.modules():
+                if isinstance(m, nn.Embedding):
+                    m.weight.normal_(mean=0, std=2e-2)
+                elif isinstance(m, nn.LayerNorm):
+                    m.bias.zero_()
+                    m.weight.fill_(1.0)
+
+        self.reset_inner_loss()
+
+    def forward(self, bs):
+        bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
+
+        # To make the code simpler in the Caterpillar layer, we pad
+        # here. It's unclear if/how much it hurts computationally by
+        # increasing the sequence length for the other layers
+
+        if self.caterpillar_length > 0:
+            original_nb = bs.nb
+            if bs.nb % self.caterpillar_length > 0:
+                bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length
+
+            bs = BracketedSequence(
+                F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
+                bs.first + self.caterpillar_length,
+                bs.nb,
+                bs.init_cache,
+            )
+
+        bs = self.embedding(bs)
+        bs = self.trunk(bs)
+        bs = self.readout(bs)
+
+        if self.caterpillar_length > 0:
+            bs = BracketedSequence(
+                F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
+                bs.first - self.caterpillar_length,
+                original_nb,
+                bs.init_cache,
+            )
+
+        return bs
+
+    # ar_mask is a tensor with 0s and 1s, of same shape as input,
+    # with 1s where tokens should be generated. The others are kept
+    # unchanged.
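+    #
+    # A minimal usage sketch (sizes made up): to let the model
+    # generate the last 5 tokens of sequences of length 10,
+    #
+    #   input = torch.randint(vocabulary_size, (2, 10))
+    #   ar_mask = (torch.arange(10)[None, :] >= 5).long().expand(2, 10)
+    #   model.masked_inplace_autoregression(input, ar_mask)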
+ + def masked_inplace_autoregression( + self, + input_src, + ar_mask_src, + forbidden_tokens=None, + deterministic_synthesis=False, + ): + input = input_src.to(self.readout.f.weight.device) + ar_mask = ar_mask_src.to(self.readout.f.weight.device) + to_generate = (ar_mask.sum(0) > 0).nonzero() + if to_generate.min() > 0: + self( + BracketedSequence(input, 0, to_generate.min(), True) + ) # Needed to initialize the model's cache + for s in range(to_generate.min(), to_generate.max() + 1): + output = self(BracketedSequence(input, s, 1, s == 0)).x + logits = output[:, s] + if forbidden_tokens is not None: + logits = logits.masked_fill(forbidden_tokens, float("-inf")) + if deterministic_synthesis: + t_next = logits.argmax(1) + else: + dist = torch.distributions.categorical.Categorical(logits=logits) + t_next = dist.sample() + input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s] + + input_src.copy_(input) + + def reset_inner_loss(self): + for m in self.modules(): + if m is not self and hasattr(m, "reset_inner_loss"): + m.reset_inner_loss() + + def get_inner_loss(self): + l = torch.tensor([0.0], device=self.readout.f.weight.device) + for m in self.modules(): + if m is not self and hasattr(m, "get_inner_loss"): + l += m.get_inner_loss() + return l + + def record_attention(self, v=True): + for m in self.modules(): + if isinstance(m, QKVAttention): + m.record_attention = v + + def retrieve_attention(self): + a = [] + for m in self.modules(): + if isinstance(m, QKVAttention): + a.append(m.a) + return a + + +###################################################################### + +if __name__ == "__main__": + print("Basic check.") + + m = Caterpillar( + dim_in=4, + dim_qk=3, + dim_v=7, + nb_heads=1, + caterpillar_length=7, + caterpillar_height=3, + attention_dropout=0.0, + ) + + m.reset_inner_loss() + x = torch.randn(1, 21 + 2 * 7, 4) + y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28] + y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28] + y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21] + y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28] + print((y1 - y2).abs().max()) + print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max()) + exit(0) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + vocabulary_size = 128 + x = torch.randint(vocabulary_size, (6, 1024)) + + model = MyGPT( + vocabulary_size=vocabulary_size, + dim_model=512, + dim_keys=64, + dim_hidden=2048, + nb_heads=8, + nb_lines=128, + nb_blocks=12, + dropout=0.1, + causal=True, + ) + + x = x.to(device) + model.to(device) + + import time, sys + + # import torchvision.models as models + # from torch.profiler import profile, record_function, ProfilerActivity + + # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof: + # with record_function("model_inference"): + + model.eval() + for i in range(3): + start_time = time.perf_counter() + for k in range(10): + model(BracketedSequence(x)) + duration = time.perf_counter() - start_time + print(duration) + sys.stdout.flush() + + # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)) + # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + + # print("##############################################################") + # y2 = torch.randn_like(y1) + # for s in range(x.size(1)): + # z = model(BracketedSequence(x, s, 1)) + # y2[:, s : s + 1] = z.slice() + + # 
print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}") + +###################################################################### diff --git a/picoclvr.py b/picoclvr.py new file mode 100755 index 0000000..0cd3062 --- /dev/null +++ b/picoclvr.py @@ -0,0 +1,370 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import math +import torch, torchvision +import torch.nn.functional as F + +color_name2rgb = { + "white": [255, 255, 255], + "red": [255, 0, 0], + "green": [0, 128, 0], + "blue": [0, 0, 255], + "yellow": [255, 255, 0], + "black": [0, 0, 0], + "maroon": [128, 0, 0], + "dark_red": [139, 0, 0], + "brown": [165, 42, 42], + "firebrick": [178, 34, 34], + "crimson": [220, 20, 60], + "tomato": [255, 99, 71], + "coral": [255, 127, 80], + "indian_red": [205, 92, 92], + "light_coral": [240, 128, 128], + "dark_salmon": [233, 150, 122], + "salmon": [250, 128, 114], + "light_salmon": [255, 160, 122], + "orange_red": [255, 69, 0], + "dark_orange": [255, 140, 0], + "orange": [255, 165, 0], + "gold": [255, 215, 0], + "dark_golden_rod": [184, 134, 11], + "golden_rod": [218, 165, 32], + "pale_golden_rod": [238, 232, 170], + "dark_khaki": [189, 183, 107], + "khaki": [240, 230, 140], + "olive": [128, 128, 0], + "yellow_green": [154, 205, 50], + "dark_olive_green": [85, 107, 47], + "olive_drab": [107, 142, 35], + "lawn_green": [124, 252, 0], + "chartreuse": [127, 255, 0], + "green_yellow": [173, 255, 47], + "dark_green": [0, 100, 0], + "forest_green": [34, 139, 34], + "lime": [0, 255, 0], + "lime_green": [50, 205, 50], + "light_green": [144, 238, 144], + "pale_green": [152, 251, 152], + "dark_sea_green": [143, 188, 143], + "medium_spring_green": [0, 250, 154], + "spring_green": [0, 255, 127], + "sea_green": [46, 139, 87], + "medium_aqua_marine": [102, 205, 170], + "medium_sea_green": [60, 179, 113], + "light_sea_green": [32, 178, 170], + "dark_slate_gray": [47, 79, 79], + "teal": [0, 128, 128], + "dark_cyan": [0, 139, 139], + "aqua": [0, 255, 255], + "cyan": [0, 255, 255], + "light_cyan": [224, 255, 255], + "dark_turquoise": [0, 206, 209], + "turquoise": [64, 224, 208], + "medium_turquoise": [72, 209, 204], + "pale_turquoise": [175, 238, 238], + "aqua_marine": [127, 255, 212], + "powder_blue": [176, 224, 230], + "cadet_blue": [95, 158, 160], + "steel_blue": [70, 130, 180], + "corn_flower_blue": [100, 149, 237], + "deep_sky_blue": [0, 191, 255], + "dodger_blue": [30, 144, 255], + "light_blue": [173, 216, 230], + "sky_blue": [135, 206, 235], + "light_sky_blue": [135, 206, 250], + "midnight_blue": [25, 25, 112], + "navy": [0, 0, 128], + "dark_blue": [0, 0, 139], + "medium_blue": [0, 0, 205], + "royal_blue": [65, 105, 225], + "blue_violet": [138, 43, 226], + "indigo": [75, 0, 130], + "dark_slate_blue": [72, 61, 139], + "slate_blue": [106, 90, 205], + "medium_slate_blue": [123, 104, 238], + "medium_purple": [147, 112, 219], + "dark_magenta": [139, 0, 139], + "dark_violet": [148, 0, 211], + "dark_orchid": [153, 50, 204], + "medium_orchid": [186, 85, 211], + "purple": [128, 0, 128], + "thistle": [216, 191, 216], + "plum": [221, 160, 221], + "violet": [238, 130, 238], + "magenta": [255, 0, 255], + "orchid": [218, 112, 214], + "medium_violet_red": [199, 21, 133], + "pale_violet_red": [219, 112, 147], + "deep_pink": [255, 20, 147], + "hot_pink": [255, 105, 180], + "light_pink": [255, 182, 193], + "pink": [255, 192, 203], + "antique_white": [250, 235, 215], + "beige": [245, 245, 220], + 
"bisque": [255, 228, 196], + "blanched_almond": [255, 235, 205], + "wheat": [245, 222, 179], + "corn_silk": [255, 248, 220], + "lemon_chiffon": [255, 250, 205], + "light_golden_rod_yellow": [250, 250, 210], + "light_yellow": [255, 255, 224], + "saddle_brown": [139, 69, 19], + "sienna": [160, 82, 45], + "chocolate": [210, 105, 30], + "peru": [205, 133, 63], + "sandy_brown": [244, 164, 96], + "burly_wood": [222, 184, 135], + "tan": [210, 180, 140], + "rosy_brown": [188, 143, 143], + "moccasin": [255, 228, 181], + "navajo_white": [255, 222, 173], + "peach_puff": [255, 218, 185], + "misty_rose": [255, 228, 225], + "lavender_blush": [255, 240, 245], + "linen": [250, 240, 230], + "old_lace": [253, 245, 230], + "papaya_whip": [255, 239, 213], + "sea_shell": [255, 245, 238], + "mint_cream": [245, 255, 250], + "slate_gray": [112, 128, 144], + "light_slate_gray": [119, 136, 153], + "light_steel_blue": [176, 196, 222], + "lavender": [230, 230, 250], + "floral_white": [255, 250, 240], + "alice_blue": [240, 248, 255], + "ghost_white": [248, 248, 255], + "honeydew": [240, 255, 240], + "ivory": [255, 255, 240], + "azure": [240, 255, 255], + "snow": [255, 250, 250], + "silver": [192, 192, 192], + "gainsboro": [220, 220, 220], + "white_smoke": [245, 245, 245], +} + +color_name2id = dict([(n, k) for k, n in enumerate(color_name2rgb.keys())]) +color_id2name = dict([(k, n) for k, n in enumerate(color_name2rgb.keys())]) + +###################################################################### + + +def all_properties(height, width, nb_squares, square_i, square_j, square_c): + s = [] + + for r, c_r in [(k, color_id2name[square_c[k].item()]) for k in range(nb_squares)]: + s += [f"there is {c_r}"] + + if square_i[r] >= height - height // 3: + s += [f"{c_r} bottom"] + if square_i[r] < height // 3: + s += [f"{c_r} top"] + if square_j[r] >= width - width // 3: + s += [f"{c_r} right"] + if square_j[r] < width // 3: + s += [f"{c_r} left"] + + for t, c_t in [ + (k, color_id2name[square_c[k].item()]) for k in range(nb_squares) + ]: + if square_i[r] > square_i[t]: + s += [f"{c_r} below {c_t}"] + if square_i[r] < square_i[t]: + s += [f"{c_r} above {c_t}"] + if square_j[r] > square_j[t]: + s += [f"{c_r} right of {c_t}"] + if square_j[r] < square_j[t]: + s += [f"{c_r} left of {c_t}"] + + return s + + +###################################################################### + +# Generates sequences + + +def generate( + nb, + height, + width, + max_nb_squares=5, + max_nb_properties=10, + nb_colors=5, + pruner=None, +): + assert nb_colors >= max_nb_squares and nb_colors <= len(color_name2rgb) - 1 + + descr = [] + + for n in range(nb): + # we want uniform over the combinations of 1 to max_nb_squares + # pixels of nb_colors + logits = math.log(nb_colors) * torch.arange(1, max_nb_squares + 1).float() + dist = torch.distributions.categorical.Categorical(logits=logits) + nb_squares = dist.sample((1,)) + 1 + # nb_squares = torch.randint(max_nb_squares, (1,)) + 1 + square_position = torch.randperm(height * width)[:nb_squares] + + # color 0 is white and reserved for the background + square_c = torch.randperm(nb_colors)[:nb_squares] + 1 + square_i = square_position.div(width, rounding_mode="floor") + square_j = square_position % width + + img = torch.zeros(height * width, dtype=torch.int64) + for k in range(nb_squares): + img[square_position[k]] = square_c[k] + + # generates all the true properties + + s = all_properties(height, width, nb_squares, square_i, square_j, square_c) + + if pruner is not None: + s = list(filter(pruner, s)) + + # 
picks at most max_nb_properties at random
+
+        nb_properties = torch.randint(max_nb_properties, (1,)) + 1
+        s = (
+            " <sep> ".join([s[k] for k in torch.randperm(len(s))[:nb_properties]])
+            + " <img> "
+            + " ".join([f"{color_id2name[n.item()]}" for n in img])
+        )
+
+        descr += [s]
+
+    return descr
+
+
+######################################################################
+
+# Extracts the image after <img> in descr as a 1x3xHxW tensor
+
+
+def descr2img(descr, height, width):
+    result = []
+
+    def token2color(t):
+        try:
+            return color_name2rgb[t]
+        except KeyError:
+            return [128, 128, 128]
+
+    for d in descr:
+        d = d.split("<img>")[1]
+        d = d.strip().split(" ")[: height * width]
+        d = d + ["<nul>"] * (height * width - len(d))
+        d = [token2color(t) for t in d]
+        img = torch.tensor(d).permute(1, 0).reshape(1, 3, height, width)
+        result.append(img)
+
+    return torch.cat(result, 0)
+
+
+######################################################################
+
+# Returns all the properties of the image after <img> in descr
+
+
+def descr2properties(descr, height, width):
+    if type(descr) == list:
+        return [descr2properties(d, height, width) for d in descr]
+
+    d = descr.split("<img>")
+    img_tokens = d[-1] if len(d) > 1 else ""
+    img_tokens = img_tokens.strip().split(" ")[: height * width]
+    if len(img_tokens) != height * width:
+        return []
+
+    seen = {}
+    for k, x in enumerate(img_tokens):
+        if x != color_id2name[0]:
+            if x in color_name2rgb:
+                if x in seen:
+                    return []
+            else:
+                return []
+            seen[x] = (color_name2id[x], k // width, k % width)
+
+    square_infos = tuple(zip(*seen.values()))
+
+    if square_infos:
+        square_c = torch.tensor(square_infos[0])
+        square_i = torch.tensor(square_infos[1])
+        square_j = torch.tensor(square_infos[2])
+    else:
+        square_c = torch.tensor([])
+        square_i = torch.tensor([])
+        square_j = torch.tensor([])
+
+    s = all_properties(height, width, len(seen), square_i, square_j, square_c)
+
+    return s
+
+
+######################################################################
+
+# Returns a triplet composed of (1) the total number of properties
+# requested before <img> in descr, (2) the total number of properties
+# the image after <img> verifies, and (3) the number of properties in
+# (1) not in (2)
+
+
+def nb_properties(descr, height, width, pruner=None):
+    if type(descr) == list:
+        return [nb_properties(d, height, width, pruner) for d in descr]
+
+    d = descr.split("<img>", 1)
+    if len(d) == 0:
+        return 0
+    d = d[0].strip().split("<sep>")
+    d = [x.strip() for x in d]
+
+    all_properties = set(descr2properties(descr, height, width))
+
+    if pruner is None:
+        requested_properties = set(d)
+    else:
+        requested_properties = set(filter(pruner, d))
+
+    missing_properties = requested_properties - all_properties
+
+    return (len(requested_properties), len(all_properties), len(missing_properties))
+
+
+######################################################################
+
+if __name__ == "__main__":
+    for n in range(16):
+        descr = generate(nb=1, height=12, width=16)
+
+        print(nb_properties(descr, height=12, width=16))
+
+        with open(f"picoclvr_example_{n:02d}.txt", "w") as f:
+            for d in descr:
+                f.write(f"{d}\n\n")
+
+        img = descr2img(descr, height=12, width=16)
+        if img.size(0) == 1:
+            img = F.pad(img, (1, 1, 1, 1), value=64)
+
+        torchvision.utils.save_image(
+            img / 255.0,
+            f"picoclvr_example_{n:02d}.png",
+            padding=1,
+            nrow=4,
+            pad_value=0.8,
+        )
+
+    import time
+
+    start_time = time.perf_counter()
+    descr = generate(nb=1000, height=12, width=16)
+    end_time = time.perf_counter()
+    print(f"{len(descr) / (end_time - start_time):.02f} samples per 
second") + +###################################################################### diff --git a/problems.py b/problems.py new file mode 100755 index 0000000..9e368c2 --- /dev/null +++ b/problems.py @@ -0,0 +1,490 @@ +#!/usr/bin/env python + +import math + +import torch, torchvision + +from torch import nn +from torch.nn import functional as F + +###################################################################### + + +class Problem: + def generate_sequences(self, nb): + pass + + def seq2str(self, seq): + return "[NOT IMPLEMENTED]" + + def compute_nb_correct(self, input, ar_mask, result): + nb_total = ar_mask.sum().item() + nb_correct = ((result == input).long() * ar_mask).sum().item() + return nb_total, nb_correct + + +#################### + + +class ProblemDegradation(Problem): + def __init__(self, nb_state_tokens=5, nb_time_steps=12, value_max=25, hard=False): + assert value_max // nb_state_tokens >= 2 + self.nb_state_tokens = nb_state_tokens + self.nb_time_steps = nb_time_steps + self.value_max = value_max + self.hard = hard + + def generate_sequences(self, nb): + x = ( + torch.rand(nb, self.nb_state_tokens).sort(dim=-1).indices == 0 + ).long() * self.value_max + seq = [x] + + for t in range(self.nb_time_steps - 1): + v = (torch.rand(x.size()).sort(dim=-1).indices + 1) * (x >= 2).long() + u = (v.max(dim=-1, keepdim=True).values == v).long() + n = ( + (u * x) + .minimum(2 + torch.randint(self.value_max // 4 - 2, x.size())) + .sum(dim=-1, keepdim=True) + ) + m = 1 + ((n - 1) * torch.rand(n.size())).long() + x = ( + x + + m * u.roll(shifts=-1, dims=-1) + - n * u + + (n - m) * u.roll(shifts=1, dims=-1) + ) + seq.append(x) + + if self.hard: + seq.reverse() + + seq = torch.cat(seq, dim=1) + return seq, seq.new_full(seq.size(), 1, dtype=torch.int64) + + def compute_nb_correct(self, input, ar_mask, result): + nb_total = result.size(0) + nb_correct = 0 + e = result.new_zeros(self.nb_state_tokens) + + for seq in result: + states = list(seq.split(self.nb_state_tokens)) + if self.hard: + states.reverse() + + d = states[0] + j = d.sort(descending=True).indices[0] + e.zero_() + e[j] = self.value_max + if (d - e).abs().sum() == 0: + nb_errors = 0 + for k in range(len(states) - 1): + d = states[k + 1] - states[k] + j = d.sort(descending=False).indices[0] + if ( + d[j] == 0 + or d[j] > self.value_max // 4 + or d[(j + 1) % e.size(0)] <= 0 + or d[(j + 1) % e.size(0)] >= -d[j] + ): + nb_errors += 1 + else: + e.zero_() + e[j] = d[j] + e[(j + 1) % e.size(0)] = d[(j + 1) % e.size(0)] + e[(j - 1) % e.size(0)] = -d[(j + 1) % e.size(0)] - d[j] + if (d - e).abs().sum() > 0: + nb_errors += 1 + if nb_errors == 0: + nb_correct += 1 + + return nb_total, nb_correct + + def seq2str(self, seq): + return " | ".join( + [" ".join([f"{x:02d}" for x in s]) for s in seq.split(self.nb_state_tokens)] + ) + + +#################### + + +class ProblemMemory(Problem): + def __init__(self, len_total=32): + self.len_total = len_total + self.max_len_pattern = 5 + self.nb_noise_tokens = 10 + self.start_pattern_token = 0 + self.end_pattern_token = 1 + self.start_result_token = 2 + self.end_result_token = 3 + self.token_string = "[]<>" + "".join( + [chr(ord("a") + k) for k in range(self.nb_noise_tokens)] + ) + + def generate_sequences(self, nb): + sequences = ( + torch.randint(self.nb_noise_tokens, (nb, self.len_total)) + + self.end_result_token + + 1 + ) + len_patterns = torch.randint(self.max_len_pattern, (nb,)) + 1 + pattern_positions = torch.randint( + self.len_total - (5 + 2 * self.max_len_pattern), (nb,) + ) + k = self.len_total 
- (3 + self.max_len_pattern) + for i in range(nb): + l = len_patterns[i] + j = pattern_positions[i] + sequences[i, j] = self.start_pattern_token + sequences[i, j + l + 2] = self.end_pattern_token + sequences[i, k] = self.start_result_token + sequences[i, k + l + 2] = self.end_result_token + sequences[i, k + 1 : k + 2 + l] = sequences[i, j + 1 : j + 2 + l] + + j = torch.arange(self.len_total)[None, :] + ar_mask = (j > k).long() * (j <= k + 1 + len_patterns[:, None]).long() + + return sequences, ar_mask + + def seq2str(self, seq): + return "".join(self.token_string[x.item()] for x in seq) + + +class ProblemTwoTargets(Problem): + def __init__(self, len_total=10, len_targets=3): + assert len_targets >= 3 + assert len_total >= 3 * len_targets - 1 + self.len_total = len_total + self.len_targets = len_targets + + def generate_sequences(self, nb): + k = torch.arange(self.len_total)[None, :] + s = torch.randint(10, (nb, self.len_total)) + l = torch.rand(nb, self.len_total) + l = l * (k <= self.len_total - self.len_targets).long() + k1 = l.argmax(dim=1, keepdim=True) + m = (k != k1).long() * (k != k1 + self.len_targets - 1).long() + s = s * m + 10 * (1 - m) + l = l * ( + 1 + - (k + self.len_targets - 1 >= k1).long() + * (k < k1 + self.len_targets).long() + ) + k2 = l.argmax(dim=1, keepdim=True) + m = (k != k2).long() * (k != k2 + self.len_targets - 1).long() + s = s * m + 11 * (1 - m) + a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :]) + a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :]) + sequences = torch.cat( + ( + s, + torch.full((nb, 1), 12), + a1, + torch.full((nb, 1), 12), + a2, + torch.full((nb, 1), 12), + ), + 1, + ) + ar_mask = (sequences == 12).long() + ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) + return sequences, ar_mask + + def seq2str(self, seq): + return "".join("0123456789-+|"[x.item()] for x in seq) + + +#################### + + +class ProblemByHeart(Problem): + def __init__(self, nb_sentences=100, len_prompt=8, len_result=8): + self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result)) + self.seq[:, len_prompt] = 10 + + def generate_sequences(self, nb): + sequences = self.seq[torch.randint(self.seq.size(0), (nb,))] + ar_mask = (sequences == 10).long() + ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) + return sequences, ar_mask + + def seq2str(self, seq): + return "".join("0123456789|"[x.item()] for x in seq) + + +#################### + + +class ProblemLearnOperator(Problem): + def __init__(self, nb_operators=100, len_source=6, len_result=9): + self.len_source = len_source + self.len_result = len_result + self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1 + self.operators = F.one_hot( + torch.rand(nb_operators, len_result, len_source).argmax(-1), + num_classes=len_source, + ) + + def generate_sequences(self, nb): + nb_operators = torch.randint(self.operators.size(0), (nb,)) + operators = self.operators[nb_operators] + nb_operators = ( + nb_operators[:, None] + // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1) + ) % 10 + marker1 = torch.full((nb, 1), 10) + source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source] + marker2 = torch.full((nb, 1), 11) + result = operators.bmm(source[:, :, None]).squeeze(-1) + sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1) + ar_mask = (sequences == 11).long() + ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) + return sequences, ar_mask + + def seq2str(self, seq): + return 
"".join("0123456789|>"[x.item()] for x in seq) + + +#################### + + +class ProblemGuessOperator(Problem): + def __init__(self, len_source=5, len_result=8): + self.len_source = len_source + self.len_result = len_result + + def generate_sequences(self, nb): + operators = F.one_hot( + torch.rand(nb, self.len_result, self.len_source).argmax(-1), + num_classes=self.len_source, + ) + source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source] + marker1 = torch.full((nb, 1), 10) + result1 = operators.bmm(source1[:, :, None]).squeeze(-1) + marker2 = torch.full((nb, 1), 11) + source2 = torch.randint(10, (nb, self.len_source)) + marker3 = torch.full((nb, 1), 12) + result2 = operators.bmm(source2[:, :, None]).squeeze(-1) + + sequences = torch.cat( + (source1, marker1, result1, marker2, source2, marker3, result2), 1 + ) + ar_mask = (sequences == 12).long() + ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) + return sequences, ar_mask + + def seq2str(self, seq): + return "".join("0123456789>|~"[x.item()] for x in seq) + + +#################### + + +class ProblemAddition(Problem): + def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False): + self.nb_digits = nb_digits + self.zero_padded = zero_padded + self.inverted_result = inverted_result + self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")]) + self.id2char = dict([(n, c) for c, n in self.char2id.items()]) + + def tensorize(self, strings): + len_max = max([len(x) for x in strings]) + return torch.cat( + [ + torch.tensor( + [ + [self.char2id[c] for c in s + "$" * (len_max - len(s))] + for s in strings + ] + ) + ], + 0, + ) + + def generate_sequences(self, nb): + sequences = [] + for k in range(nb): + a, b = torch.randint(10**self.nb_digits, (2,)) + c = a + b + a, b, c = str(a.item()), str(b.item()), str(c.item()) + if self.zero_padded: + a = "0" * (self.nb_digits - len(a)) + a + b = "0" * (self.nb_digits - len(b)) + b + c = "0" * (self.nb_digits + 1 - len(c)) + c + if self.inverted_result: + c = c[::-1] + sequences.append(f"{a}+{b}={c}$") + + sequences = self.tensorize(sequences) + ar_mask = (sequences == self.char2id["="]).long() + ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) + return sequences, ar_mask + + def seq2str(self, seq): + return "".join(self.id2char[x.item()] for x in seq) + + +#################### + + +class ProblemMixing(Problem): + def __init__( + self, height=4, width=4, nb_time_steps=9, hard=False, random_start=True + ): + self.height = height + self.width = width + self.nb_time_steps = nb_time_steps + self.hard = hard + self.random_start = random_start + + def start_random(self, nb): + y = torch.arange(self.height * self.width).reshape(1, -1).expand(nb, -1) + + if self.random_start: + i = ( + torch.arange(self.height) + .reshape(1, -1, 1) + .expand(nb, self.height, self.width) + ) + j = ( + torch.arange(self.width) + .reshape(1, 1, -1) + .expand(nb, self.height, self.width) + ) + + ri = torch.randint(self.height, (nb,)).reshape(nb, 1, 1) + rj = torch.randint(self.width, (nb,)).reshape(nb, 1, 1) + + m = 1 - torch.logical_or(i == ri, j == rj).long().flatten(1) + + y = y * m + self.height * self.width * (1 - m) + + y = y.reshape(nb, self.height, self.width) + + return y + + def start_error(self, x): + if self.random_start: + i = ( + torch.arange(self.height, device=x.device) + .reshape(1, -1, 1) + .expand_as(x) + ) + j = torch.arange(self.width, device=x.device).reshape(1, 1, -1).expand_as(x) + + ri = ( + (x == self.height * self.width) + .long() + .sum(dim=-1) + 
.argmax(-1) + .view(-1, 1, 1) + ) + rj = ( + (x == self.height * self.width) + .long() + .sum(dim=-2) + .argmax(-1) + .view(-1, 1, 1) + ) + + m = 1 - torch.logical_or(i == ri, j == rj).long().flatten(1) + else: + m = 1 + + x = x.flatten(1) + u = torch.arange(self.height * self.width, device=x.device).reshape(1, -1) + + d = (x - (m * u + (1 - m) * self.height * self.width)).abs().sum(-1) + + return d + + def moves(self, x): + y = ( + x[:, None, :, :] + .expand(-1, self.height * 2 + self.width * 2, -1, -1) + .clone() + ) + k = 0 + + for i in range(self.height): + y[:, k, i, :] = y[:, k, i, :].roll(dims=-1, shifts=-1) + k += 1 + y[:, k, i, :] = y[:, k, i, :].roll(dims=-1, shifts=1) + k += 1 + + for j in range(self.width): + y[:, k, :, j] = y[:, k, :, j].roll(dims=-1, shifts=-1) + k += 1 + y[:, k, :, j] = y[:, k, :, j].roll(dims=-1, shifts=1) + k += 1 + + return y + + def generate_sequences(self, nb): + x = self.start_random(nb) + + seq = [x.flatten(1)] + + for t in range(self.nb_time_steps - 1): + y = self.moves(x) + x = y[torch.arange(nb), torch.randint(y.size(1), (nb,))] + seq.append(x.flatten(1)) + + if self.hard: + seq.reverse() + + seq = torch.cat(seq, dim=1) + return seq, seq.new_full(seq.size(), 1, dtype=torch.int64) + + def compute_nb_correct(self, input, ar_mask, result): + a = [ + x.reshape(result.size(0), self.height, self.width) + for x in result.split(self.height * self.width, dim=1) + ] + if self.hard: + a.reverse() + + x = a[0] + + d = self.start_error(x) + + for t in range(self.nb_time_steps - 1): + x0, x = a[t], a[t + 1] + y = self.moves(x0) + d = d + (x[:, None] - y).abs().sum((-1, -2)).min(dim=-1).values + + nb_total, nb_correct = result.size(0), (d == 0).long().sum().item() + + return nb_total, nb_correct + + def seq2str(self, seq): + return " | ".join( + [ + " ".join( + [ + "-".join( + [ + f"{x:02d}" if x < self.height * self.width else "**" + for x in s + ] + ) + for s in r.split(self.width) + ] + ) + for r in seq.split(self.height * self.width) + ] + ) + + +#################### + +if __name__ == "__main__": + p = ProblemMixing(height=3, width=3, random_start=False) + + s, m = p.generate_sequences(10000) + for x in s[:5]: + print(p.seq2str(x)) + print(p.compute_nb_correct(None, None, s)) diff --git a/pscan.py b/pscan.py new file mode 100755 index 0000000..0ec7b13 --- /dev/null +++ b/pscan.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. 
+# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import torch + +###################################################################### + + +class PScan(torch.autograd.Function): + # Given A is NxTx1 and X is NxTxD, expands A and X in place in O(T), + # and O(log(T)) if not core-bounded, so that + # + # Y[:, 0] = Y_init + # Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t] + # + # can be computed as + # + # Y[:, t] = A[:, t] * Y_init + X[:, t] + + @staticmethod + def expand_(A, X): + if A.size(1) == 1: + return + T = 2 * (A.size(1) // 2) + Aa = A[:, :T].view(A.size(0), T // 2, 2, -1, 1) + Xa = X[:, :T].view(X.size(0), T // 2, 2, -1, X.size(-1)) + Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0])) + Aa[:, :, 1].mul_(Aa[:, :, 0]) + PScan.expand_(Aa[:, :, 1], Xa[:, :, 1]) + Xa[:, 1:, 0].add_(Aa[:, 1:, 0].mul(Xa[:, :-1, 1])) + Aa[:, 1:, 0].mul_(Aa[:, :-1, 1]) + if T < A.size(1): + X[:, -1].add_(A[:, -1].mul(X[:, -2])) + A[:, -1].mul_(A[:, -2]) + + @staticmethod + def acc_rev_(A, X): + if X.size(1) == 1: + return + T = 2 * (X.size(1) // 2) + Aa = A[:, -T:].view(A.size(0), T // 2, 2, -1, 1) + Xa = X[:, -T:].view(X.size(0), T // 2, 2, -1, X.size(-1)) + Xa[:, :, 0].add_(Aa[:, :, 1].mul(Xa[:, :, 1])) + B = Aa[:, :, 0].clone() + B[:, 1:].mul_(Aa[:, :-1, 1]) + PScan.acc_rev_(B, Xa[:, :, 0]) + Xa[:, :-1, 1].add_(Aa[:, 1:, 0].mul(Xa[:, 1:, 0])) + if T < A.size(1): + X[:, 0].add_(A[:, 1].mul(X[:, 1])) + + # A is NxT, X is NxTxD, Y_init is NxD + # + # returns Y of same shape as X, with + # + # Y[:, t] = A[:, 0] * Y_init + X[:, 0] if t == 0 + # = A[:, t] * Y[:, t-1] + X[:, t] otherwise + + @staticmethod + def forward(ctx, A, X, Y_init): + ctx.A = A.unsqueeze(-1).clone() + ctx.Y_init = Y_init[:, None].clone() + ctx.A_star = ctx.A.clone() + ctx.X_star = X.clone() + PScan.expand_(ctx.A_star, ctx.X_star) + return ctx.A_star * ctx.Y_init + ctx.X_star + + @staticmethod + def backward(ctx, grad_output): + U = grad_output * ctx.A_star + A = ctx.A.clone() + R = grad_output.clone() + PScan.acc_rev_(A, R) + Q = ctx.Y_init.expand_as(ctx.X_star).clone() + Q[:, 1:].mul_(ctx.A_star[:, :-1]).add_(ctx.X_star[:, :-1]) + return (Q * R).sum(-1), R, U.sum(dim=1) + + +pscan = PScan.apply + +###################################################################### + +if __name__ == "__main__": + import time, sys + + A = torch.rand(17, 12, 3) + X = torch.rand(17, 12, 3, 11) + Y_init = torch.rand(17, 3, 11) + Y = pscan(A, X, Y_init) + exit(0) + + N, T, D = 2, 1047, 3 + + A = torch.rand(N, T, dtype=torch.float64).requires_grad_() + X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_() + Y_init = torch.randn(N, D, dtype=torch.float64).requires_grad_() + + # Iterative implementation + + y = Y_init + s = 0 + + for k in range(A.size(1)): + y = A[:, k, None] * y + X[:, k] + s = s + y + + s = s.sum() + + gA_ref, gX_ref, gY_init_ref = torch.autograd.grad( + s, (A, X, Y_init), retain_graph=True + ) + + # parallel scan + + start_time = time.perf_counter() + for _ in range(1000): + Y = pscan(A, X, Y_init) + duration = time.perf_counter() - start_time + print(f"duration {duration}") + + s = Y.sum() + + gA, gX, gY_init = torch.autograd.grad(s, (A, X, Y_init), retain_graph=True) + + # print(gA) + # print(gX) + # print(gY_init) + + print((gA - gA_ref).norm()) + print((gX - gX_ref).norm()) + print((gY_init - gY_init_ref).norm()) + + Y1 = pscan(A[:, : T // 2], X[:, : T // 2], Y_init) + Y2 = pscan(A[:, T // 2 :], X[:, T // 2 :], Y1[:, -1]) + + print((Y - torch.cat([Y1, Y2], dim=1)).norm()) diff --git a/qmlp.py 
b/qmlp.py new file mode 100755 index 0000000..abebfc1 --- /dev/null +++ b/qmlp.py @@ -0,0 +1,378 @@ +#!/usr/bin/env python + +# @XREMOTE_HOST: elk.fleuret.org +# @XREMOTE_EXEC: python +# @XREMOTE_PRE: source ${HOME}/misc/venv/pytorch/bin/activate +# @XREMOTE_PRE: killall -u ${USER} -q -9 python || true +# @XREMOTE_PRE: ln -sf ${HOME}/data/pytorch ./data +# @XREMOTE_SEND: *.py *.sh + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import math, sys + +import torch, torchvision + +from torch import nn +from torch.nn import functional as F + +###################################################################### + +nb_quantization_levels = 101 + + +def quantize(x, xmin, xmax): + return ( + ((x - xmin) / (xmax - xmin) * nb_quantization_levels) + .long() + .clamp(min=0, max=nb_quantization_levels - 1) + ) + + +def dequantize(q, xmin, xmax): + return q / nb_quantization_levels * (xmax - xmin) + xmin + + +###################################################################### + + +def generate_sets_and_params( + batch_nb_mlps, + nb_samples, + batch_size, + nb_epochs, + device=torch.device("cpu"), + print_log=False, + save_as_examples=False, +): + data_input = torch.zeros(batch_nb_mlps, 2 * nb_samples, 2, device=device) + data_targets = torch.zeros( + batch_nb_mlps, 2 * nb_samples, dtype=torch.int64, device=device + ) + + nb_rec = 8 + nb_values = 2 # more increases the min-max gap + + rec_support = torch.empty(batch_nb_mlps, nb_rec, 4, device=device) + + while (data_targets.float().mean(-1) - 0.5).abs().max() > 0.1: + i = (data_targets.float().mean(-1) - 0.5).abs() > 0.1 + nb = i.sum() + support = torch.rand(nb, nb_rec, 2, nb_values, device=device) * 2 - 1 + support = support.sort(-1).values + support = support[:, :, :, torch.tensor([0, nb_values - 1])].view(nb, nb_rec, 4) + + x = torch.rand(nb, 2 * nb_samples, 2, device=device) * 2 - 1 + y = ( + ( + (x[:, None, :, 0] >= support[:, :, None, 0]).long() + * (x[:, None, :, 0] <= support[:, :, None, 1]).long() + * (x[:, None, :, 1] >= support[:, :, None, 2]).long() + * (x[:, None, :, 1] <= support[:, :, None, 3]).long() + ) + .max(dim=1) + .values + ) + + data_input[i], data_targets[i], rec_support[i] = x, y, support + + train_input, train_targets = ( + data_input[:, :nb_samples], + data_targets[:, :nb_samples], + ) + test_input, test_targets = data_input[:, nb_samples:], data_targets[:, nb_samples:] + + q_train_input = quantize(train_input, -1, 1) + train_input = dequantize(q_train_input, -1, 1) + + q_test_input = quantize(test_input, -1, 1) + test_input = dequantize(q_test_input, -1, 1) + + if save_as_examples: + a = ( + 2 + * torch.arange(nb_quantization_levels).float() + / (nb_quantization_levels - 1) + - 1 + ) + xf = torch.cat( + [ + a[:, None, None].expand( + nb_quantization_levels, nb_quantization_levels, 1 + ), + a[None, :, None].expand( + nb_quantization_levels, nb_quantization_levels, 1 + ), + ], + 2, + ) + xf = xf.reshape(1, -1, 2).expand(min(q_train_input.size(0), 10), -1, -1) + print(f"{xf.size()=} {x.size()=}") + yf = ( + ( + (xf[:, None, :, 0] >= rec_support[: xf.size(0), :, None, 0]).long() + * (xf[:, None, :, 0] <= rec_support[: xf.size(0), :, None, 1]).long() + * (xf[:, None, :, 1] >= rec_support[: xf.size(0), :, None, 2]).long() + * (xf[:, None, :, 1] <= rec_support[: xf.size(0), :, None, 3]).long() + ) + .max(dim=1) + .values + ) + + full_input, full_targets = xf, yf + + q_full_input = quantize(full_input, -1, 1) + full_input = 
dequantize(q_full_input, -1, 1) + + for k in range(q_full_input[:10].size(0)): + with open(f"example_full_{k:04d}.dat", "w") as f: + for u, c in zip(full_input[k], full_targets[k]): + f.write(f"{c} {u[0].item()} {u[1].item()}\n") + + for k in range(q_train_input[:10].size(0)): + with open(f"example_train_{k:04d}.dat", "w") as f: + for u, c in zip(train_input[k], train_targets[k]): + f.write(f"{c} {u[0].item()} {u[1].item()}\n") + + hidden_dim = 32 + w1 = torch.randn(batch_nb_mlps, hidden_dim, 2, device=device) / math.sqrt(2) + b1 = torch.zeros(batch_nb_mlps, hidden_dim, device=device) + w2 = torch.randn(batch_nb_mlps, 2, hidden_dim, device=device) / math.sqrt( + hidden_dim + ) + b2 = torch.zeros(batch_nb_mlps, 2, device=device) + + w1.requires_grad_() + b1.requires_grad_() + w2.requires_grad_() + b2.requires_grad_() + optimizer = torch.optim.Adam([w1, b1, w2, b2], lr=1e-2) + + criterion = nn.CrossEntropyLoss() + criterion.to(device) + + for k in range(nb_epochs): + acc_train_loss = 0.0 + nb_train_errors = 0 + + for input, targets in zip( + train_input.split(batch_size, dim=1), train_targets.split(batch_size, dim=1) + ): + h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :] + h = F.relu(h) + output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :] + loss = F.cross_entropy( + output.reshape(-1, output.size(-1)), targets.reshape(-1) + ) + acc_train_loss += loss.item() * input.size(0) + + wta = output.argmax(-1) + nb_train_errors += (wta != targets).long().sum(-1) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + with torch.no_grad(): + for p in [w1, b1, w2, b2]: + m = ( + torch.rand(p.size(), device=p.device) <= k / (nb_epochs - 1) + ).long() + pq = quantize(p, -2, 2) + p[...] = (1 - m) * p + m * dequantize(pq, -2, 2) + + train_error = nb_train_errors / train_input.size(1) + acc_train_loss = acc_train_loss / train_input.size(1) + + # print(f"{k=} {acc_train_loss=} {train_error=}") + + acc_test_loss = 0 + nb_test_errors = 0 + + for input, targets in zip( + test_input.split(batch_size, dim=1), test_targets.split(batch_size, dim=1) + ): + h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :] + h = F.relu(h) + output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :] + loss = F.cross_entropy(output.reshape(-1, output.size(-1)), targets.reshape(-1)) + acc_test_loss += loss.item() * input.size(0) + + wta = output.argmax(-1) + nb_test_errors += (wta != targets).long().sum(-1) + + test_error = nb_test_errors / test_input.size(1) + q_params = torch.cat( + [quantize(p.view(batch_nb_mlps, -1), -2, 2) for p in [w1, b1, w2, b2]], dim=1 + ) + q_train_set = torch.cat([q_train_input, train_targets[:, :, None]], -1).reshape( + batch_nb_mlps, -1 + ) + q_test_set = torch.cat([q_test_input, test_targets[:, :, None]], -1).reshape( + batch_nb_mlps, -1 + ) + + return q_train_set, q_test_set, q_params, test_error + + +###################################################################### + + +def evaluate_q_params( + q_params, + q_set, + batch_size=25, + device=torch.device("cpu"), + nb_mlps_per_batch=1024, + save_as_examples=False, +): + errors = [] + nb_mlps = q_params.size(0) + + for n in range(0, nb_mlps, nb_mlps_per_batch): + batch_nb_mlps = min(nb_mlps_per_batch, nb_mlps - n) + batch_q_params = q_params[n : n + batch_nb_mlps] + batch_q_set = q_set[n : n + batch_nb_mlps] + hidden_dim = 32 + w1 = torch.empty(batch_nb_mlps, hidden_dim, 2, device=device) + b1 = torch.empty(batch_nb_mlps, hidden_dim, device=device) + w2 = torch.empty(batch_nb_mlps, 2, hidden_dim, 
device=device) + b2 = torch.empty(batch_nb_mlps, 2, device=device) + + with torch.no_grad(): + k = 0 + for p in [w1, b1, w2, b2]: + print(f"{p.size()=}") + x = dequantize( + batch_q_params[:, k : k + p.numel() // batch_nb_mlps], -2, 2 + ).view(p.size()) + p.copy_(x) + k += p.numel() // batch_nb_mlps + + batch_q_set = batch_q_set.view(batch_nb_mlps, -1, 3) + data_input = dequantize(batch_q_set[:, :, :2], -1, 1).to(device) + data_targets = batch_q_set[:, :, 2].to(device) + + print(f"{data_input.size()=} {data_targets.size()=}") + + criterion = nn.CrossEntropyLoss() + criterion.to(device) + + acc_loss = 0.0 + nb_errors = 0 + + for input, targets in zip( + data_input.split(batch_size, dim=1), data_targets.split(batch_size, dim=1) + ): + h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :] + h = F.relu(h) + output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :] + loss = F.cross_entropy( + output.reshape(-1, output.size(-1)), targets.reshape(-1) + ) + acc_loss += loss.item() * input.size(0) + wta = output.argmax(-1) + nb_errors += (wta != targets).long().sum(-1) + + errors.append(nb_errors / data_input.size(1)) + acc_loss = acc_loss / data_input.size(1) + + return torch.cat(errors) + + +###################################################################### + + +def generate_sequence_and_test_set( + nb_mlps, + nb_samples, + batch_size, + nb_epochs, + device, + nb_mlps_per_batch=1024, +): + seqs, q_test_sets, test_errors = [], [], [] + + for n in range(0, nb_mlps, nb_mlps_per_batch): + q_train_set, q_test_set, q_params, test_error = generate_sets_and_params( + batch_nb_mlps=min(nb_mlps_per_batch, nb_mlps - n), + nb_samples=nb_samples, + batch_size=batch_size, + nb_epochs=nb_epochs, + device=device, + ) + + seqs.append( + torch.cat( + [ + q_train_set, + q_train_set.new_full( + ( + q_train_set.size(0), + 1, + ), + nb_quantization_levels, + ), + q_params, + ], + dim=-1, + ) + ) + + q_test_sets.append(q_test_set) + test_errors.append(test_error) + + seq = torch.cat(seqs) + q_test_set = torch.cat(q_test_sets) + test_error = torch.cat(test_errors) + + return seq, q_test_set, test_error + + +###################################################################### + +if __name__ == "__main__": + import time + + batch_nb_mlps, nb_samples = 128, 250 + + generate_sets_and_params( + batch_nb_mlps=10, + nb_samples=nb_samples, + batch_size=25, + nb_epochs=100, + device=torch.device("cpu"), + print_log=False, + save_as_examples=True, + ) + + exit(0) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + start_time = time.perf_counter() + + data = [] + + seq, q_test_set, test_error = generate_sequence_and_test_set( + nb_mlps=batch_nb_mlps, + nb_samples=nb_samples, + device=device, + batch_size=25, + nb_epochs=250, + nb_mlps_per_batch=17, + ) + + end_time = time.perf_counter() + print(f"{seq.size(0) / (end_time - start_time):.02f} samples per second") + + q_train_set = seq[:, : nb_samples * 3] + q_params = seq[:, nb_samples * 3 + 1 :] + print(f"SANITY #2 {q_train_set.size()=} {q_params.size()=} {seq.size()=}") + error_train = evaluate_q_params(q_params, q_train_set, nb_mlps_per_batch=17) + print(f"train {error_train*100}%") + error_test = evaluate_q_params(q_params, q_test_set, nb_mlps_per_batch=17) + print(f"test {error_test*100}%") diff --git a/rpl.py b/rpl.py new file mode 100755 index 0000000..b848afa --- /dev/null +++ b/rpl.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. 
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret
+
+import math
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+######################################################################
+
+
+def rpl_exec(program, stack):
+    stack = stack.copy()
+    for op in program:
+        if op == "add":
+            if len(stack) > 1:
+                a, b = stack.pop(), stack.pop()
+                stack.append(a + b)
+        elif op == "min":
+            if len(stack) > 1:
+                a, b = stack.pop(), stack.pop()
+                stack.append(min(a, b))
+        elif op == "max":
+            if len(stack) > 1:
+                a, b = stack.pop(), stack.pop()
+                stack.append(max(a, b))
+        elif op == "swp":
+            if len(stack) > 1:
+                a, b = stack.pop(), stack.pop()
+                stack.append(a)
+                stack.append(b)
+        elif op == "rep":
+            if len(stack) > 1:
+                a, b = stack.pop(), stack.pop()
+                stack += [b] * a
+        elif op == "dup":
+            if len(stack) > 0:
+                a = stack.pop()
+                stack.append(a)
+                stack.append(a)
+        elif op == "del":
+            if len(stack) > 0:
+                a = stack.pop()
+        else:
+            raise ValueError(f"Unknown instruction {op}")
+
+    return stack
+
+
+rpl_ops = ["add", "min", "max", "swp", "rep", "dup", "del"]
+
+######################################################################
+
+
+def generate(
+    nb_starting_values=3, nb_result_values_max=None, max_input=9, prog_len=6, nb_runs=5
+):
+    prog_len = (1 + torch.randint(2 * prog_len, (1,))).clamp(max=prog_len).item()
+
+    while True:
+        no_empty_stack = True
+        prog = [rpl_ops[k] for k in torch.randint(len(rpl_ops), (prog_len,))]
+
+        result = []
+        for _ in range(nb_runs):
+            stack = [
+                x.item() for x in torch.randint(max_input + 1, (nb_starting_values,))
+            ]
+            result_stack = rpl_exec(prog, stack)
+            if len(result_stack) == 0:
+                no_empty_stack = False
+            result = result + ["<in>"] + stack + ["<out>"] + result_stack
+
+        result = result + ["<prg>"] + prog
+        result = result + ["<end>"]
+
+        if no_empty_stack and (
+            nb_result_values_max is None or len(result_stack) <= nb_result_values_max
+        ):
+            break
+
+    return result
+
+
+def next_marker(seq, tokens, start=0):
+    pos = None
+    for t in tokens:
+        try:
+            i = seq.index(t, start)
+            if pos is None or i < pos:
+                pos = i
+        except ValueError:
+            pass
+    return pos
+
+
+def decompose(seq):
+    io = []
+    k = 0
+    while seq[k] == "<in>":
+        o = next_marker(seq, ["<out>"], start=k + 1)
+        if o is None:
+            raise ValueError("Missing output markers (should be correct in the prompt)")
+        e = next_marker(seq, ["<in>", "<prg>"], start=o)
+        if e is None:
+            raise ValueError(
+                "Missing input/output markers (should be correct in the prompt)"
+            )
+        try:
+            io.append(
+                ([int(x) for x in seq[k + 1 : o]], [int(x) for x in seq[o + 1 : e]])
+            )
+        except ValueError:
+            raise ValueError(
+                "Invalid input/output value (should be correct in the prompt)"
+            )
+
+        k = e
+
+    if seq[k] == "<prg>":
+        e = next_marker(seq, ["<end>"], start=k)
+        if e is None:
+            prog = []
+        else:
+            prog = seq[k + 1 : e]
+    else:
+        raise ValueError("Missing <prg> (it should be in the prompt)")
+
+    return prog, io
+
+
+def stack_distance(target_stack, result_stack):
+    return abs(len(result_stack) - len(target_stack)) + sum(
+        [0 if x == y else 1 for x, y in zip(result_stack, target_stack)]
+    )
+
+
+def compute_nb_errors(seq):
+    prog, io = decompose(seq)
+
+    nb_total, nb_errors = 0, 0
+
+    stacks = []
+
+    if len(set(prog) - set(rpl_ops)) > 0:
+        # Program is not valid, we count 100% error
+        for start_stack, target_stack in io:
+            stacks.append((start_stack, target_stack, ["N/A"], False))
+            nb_total += len(target_stack)
+            nb_errors += len(target_stack)
+
+    else:
+        # Program is valid
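+        # Each run is scored with stack_distance, i.e. the difference
+        # in length plus the number of mismatching values, e.g.
+        # stack_distance([3, 1, 4], [3, 7]) == 2
+        for start_stack,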
target_stack in io: + result_stack = rpl_exec(prog, start_stack) + nb_total += len(target_stack) + e = stack_distance(target_stack, result_stack) + nb_errors += e + stacks.append((start_stack, target_stack, result_stack, e == 0)) + + return nb_total, nb_errors, prog, stacks + + +###################################################################### + +if __name__ == "__main__": + seq = generate() + print(seq) + seq[3] = 7 + print(seq) + print(compute_nb_errors(seq)) diff --git a/snake.py b/snake.py new file mode 100755 index 0000000..8a16f9f --- /dev/null +++ b/snake.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import torch, torchvision +import torch.nn.functional as F + + +def generate_sequences( + nb, height, width, nb_colors, length, prompt_length, device=torch.device("cpu") +): + worlds = torch.randint(nb_colors, (nb, height, width), device=device) + world_prior_visits = torch.zeros(nb, height, width, device=device) + + # nb x 2 + snake_position = torch.cat( + ( + torch.randint(height, (nb, 1), device=device), + torch.randint(width, (nb, 1), device=device), + ), + 1, + ) + snake_direction = torch.randint(4, (nb,), device=device) + sequences = torch.empty(nb, 2 * length, device=device, dtype=torch.int64) + sequences_prior_visits = torch.zeros( + nb, 2 * length, device=device, dtype=torch.int64 + ) + i = torch.arange(nb, device=device) # [:,None] + + for l in range(length): + # nb x 3 + snake_next_direction = torch.cat( + ( + (snake_direction[:, None] - 1) % 4, + snake_direction[:, None], + (snake_direction[:, None] + 1) % 4, + ), + 1, + ) + + # nb x 3 + vh = (snake_next_direction + 1) % 2 * (snake_next_direction - 1) + vw = snake_next_direction % 2 * (snake_next_direction - 2) + + # nb x 3 x 2 + snake_next_speed = torch.cat((vh[:, :, None], vw[:, :, None]), 2) + snake_next_position = snake_position[:, None, :] + snake_next_speed + + # nb x 3 + val = torch.logical_and( + torch.logical_and( + snake_next_position[:, :, 0] >= 0, snake_next_position[:, :, 0] < height + ), + torch.logical_and( + snake_next_position[:, :, 1] >= 0, snake_next_position[:, :, 1] < width + ), + ).float() + val = ( + # The multiplicative factors bias toward moving forward + torch.rand_like(val) + * val + * torch.tensor([[1.0, 2.0, 1.0]], device=device) + ) + + # nb + j = val.argmax(1) + snake_direction = snake_next_direction[i, j] + + sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4 + sequences_prior_visits[:, 2 * l] = world_prior_visits[ + i, snake_position[:, 0], snake_position[:, 1] + ] + if l < prompt_length: + world_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1 + sequences[:, 2 * l + 1] = snake_direction + + # nb x 2 + snake_position = snake_next_position[i, j] + + return sequences, sequences_prior_visits, worlds, world_prior_visits + + +# generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20) +# exit(0) + + +def solver(input, ar_mask): + for n in range(input.size(0)): + i, j, memory = 0, 0, {} + # print(input[n]) + # print(ar_mask[n]) + for l in range(input.size(1) // 2): + if ar_mask[n, 2 * l] == 1: + if memory.get((i, j)) is None: + input[n, 2 * l] = -1 + else: + input[n, 2 * l] = memory[(i, j)] + else: + # print(f'@3 {memory=}') + if memory.get((i, j)) is None: + memory[(i, j)] = input[n, 2 * l] + else: + assert memory[(i, j)] == input[n, 2 * l], f"n={n} l={l}" + # print(f'@1 {i=} {j=}') + d = 
input[n, 2 * l + 1].item() + i += (d + 1) % 2 * (d - 1) + j += d % 2 * (d - 2) + # print(f'@2 {i=} {j=}') + + +def seq2str(seq): + return "".join(["NESW123456789"[i] for i in seq]) + + +###################################################################### + +if __name__ == "__main__": + train_input, train_prior_visits, _, _ = generate_sequences( + nb=20, + height=9, + width=12, + nb_colors=5, + length=50, + prompt_length=100, + ) + + print([seq2str(s) for s in train_input]) + +###################################################################### diff --git a/stack.py b/stack.py new file mode 100755 index 0000000..543f04e --- /dev/null +++ b/stack.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import torch, torchvision + +###################################################################### + +# CODE_OP=[0 for push, 1 for pop] + 2 * n_stack +# CODE_VAL=val + 2 * nb_stacks + + +def generate_sequences( + nb, nb_steps, nb_stacks, nb_digits, values=None, device=torch.device("cpu") +): + stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64) + stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64) + k = torch.arange(nb) + result = torch.empty(nb, (1 + nb_digits) * nb_steps, dtype=torch.int64) + recorded_stack_counts = torch.zeros( + nb, (1 + nb_digits) * nb_steps, dtype=torch.int64 + ) + + for t in range(nb_steps): + op = torch.randint(2, (nb,)) + st = torch.randint(nb_stacks, (nb,)) + op = op * (stack_counts[k, st] > 0) + if values is None: + val_push = torch.randint(10**nb_digits, (nb,)) + else: + val_push = values[torch.randint(values.size(0), (nb,))] + val_pop = stack[ + k, + st, + (stack_counts[k, st] - 1).clamp(min=0), + ] + stack[k, st, stack_counts[k, st]] = val_push + recorded_stack_counts[:, (1 + nb_digits) * t] = stack_counts[k, st] + stack_counts[k[op == 0], st[op == 0]] += 1 + stack_counts[k[op == 1], st[op == 1]] -= 1 + result[:, (1 + nb_digits) * t] = st * 2 + op + for d in range(nb_digits): + result[:, (1 + nb_digits) * t + 1 + d] = ( + (op * val_pop + (1 - op) * val_push) // (10**d) + ) % 10 + 2 * nb_stacks + + return result.to(device), recorded_stack_counts.to(device) + + +def remove_popped_values(seq, nb_stacks, nb_digits): + m = torch.logical_and(seq % 2 == 1, seq < 2 * nb_stacks).long() + for d in range(nb_digits): + k = d + 1 + seq[:, k:] = -m[:, :-k] + (1 - m[:, :-k]) * seq[:, k:] + + +def seq_to_str(seq, nb_stacks, nb_digits, recorded_stack_counts=None): + assert seq.size(0) % (1 + nb_digits) == 0 + s = "" + for t in range(seq.size(0) // (1 + nb_digits)): + n_op = seq[(1 + nb_digits) * t] + if t > 0: + s += " " + if recorded_stack_counts is not None: + s += f"[{recorded_stack_counts[(1 + nb_digits)*t]}] " + s += f"POP" if n_op % 2 == 1 else f"PSH" + if nb_stacks > 1: + s += f"_{n_op//2}" + for d in range(nb_digits): + if seq[(1 + nb_digits) * t + 1 + d] == -1: + s += " ?" 
+ else: + s += f" {seq[(1 + nb_digits) * t + 1 + d] - 2 * nb_stacks:1d}" + return s + + +###################################################################### + +if __name__ == "__main__": + nb, nb_steps, nb_stacks, nb_digits = 150000, 20, 2, 1 + seq, recorded_stack_counts = generate_sequences( + nb=nb, + nb_steps=nb_steps, + nb_stacks=nb_stacks, + nb_digits=nb_digits, + ) + + for n in range(min(10, seq.size(0))): + print( + seq_to_str( + seq[n], + nb_stacks=nb_stacks, + nb_digits=nb_digits, + recorded_stack_counts=recorded_stack_counts[n], + ) + ) + # print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits)) + + print("-- PREPARED FOR TEST -----------------") + + remove_popped_values(seq, nb_stacks, nb_digits) + + for n in range(min(10, seq.size(0))): + print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits)) diff --git a/tasks.py b/tasks.py new file mode 100755 index 0000000..58638ed --- /dev/null +++ b/tasks.py @@ -0,0 +1,1663 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import math, os, tqdm + +import torch, torchvision + +from torch import nn +from torch.nn import functional as F + +from mygpt import BracketedSequence + +# from graph import save_attention_image +save_attention_image = None + +###################################################################### + + +def masked_inplace_autoregression( + model, + batch_size, + input, + ar_mask, + deterministic_synthesis, + forbidden_tokens=None, + progress_bar_desc="autoregression", + device=torch.device("cpu"), +): + assert input.size() == ar_mask.size() + + batches = zip(input.split(batch_size), ar_mask.split(batch_size)) + + if progress_bar_desc is not None: + batches = tqdm.tqdm( + batches, + dynamic_ncols=True, + desc=progress_bar_desc, + total=(input.size(0) + batch_size - 1) // batch_size, + ) + + with torch.autograd.no_grad(): + t = model.training + model.eval() + + for input, ar_mask in batches: + model.masked_inplace_autoregression( + input, ar_mask, forbidden_tokens, deterministic_synthesis + ) + + model.train(t) + + +###################################################################### + + +class Task: + def batches(self, split="train"): + pass + + def vocabulary_size(self): + pass + + def produce_results( + self, n_epoch, model, result_dir, logger, deterministic_synthesis + ): + pass + + +#################### + +import problems + + +class SandBox(Task): + def __init__( + self, + problem, + nb_train_samples, + nb_test_samples, + batch_size, + logger=None, + device=torch.device("cpu"), + max_nb_codes=1024, + ): + super().__init__() + + self.batch_size = batch_size + self.device = device + self.problem = problem + + self.train_input, self.train_ar_mask = self.problem.generate_sequences( + nb_train_samples + ) + self.test_input, self.test_ar_mask = self.problem.generate_sequences( + nb_test_samples + ) + + self.train_input, self.train_ar_mask = self.train_input.to( + device + ), self.train_ar_mask.to(device) + self.test_input, self.test_ar_mask = self.test_input.to( + device + ), self.test_ar_mask.to(device) + + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + + # A bit of paranoia never hurts + assert self.nb_codes <= max_nb_codes + assert self.train_input.min() >= 0 + assert self.test_input.min() >= 0 + assert tuple(x.item() for x in self.train_ar_mask.unique()) in { + (0,), + (1,), + (0, 1), + } + assert tuple(x.item() for x in self.test_ar_mask.unique()) 
in { + (0,), + (1,), + (0, 1), + } + + if logger is not None: + for s, a in zip(self.train_input[:100], self.train_ar_mask[:100]): + logger(f"train_sequences {self.problem.seq2str(s)}") + a = "".join(["01"[x.item()] for x in a]) + logger(f" {a}") + + def batches(self, split="train", nb_to_use=-1, desc=None): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + if nb_to_use > 0: + input = input[:nb_to_use] + if desc is None: + desc = f"epoch-{split}" + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=desc + ): + yield batch + + def vocabulary_size(self): + return self.nb_codes + + def produce_results( + self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000 + ): + def compute_accuracy(input, ar_mask, logger=None): + input, ar_mask = input[:nmax], ar_mask[:nmax] + result = input.clone() * (1 - ar_mask) + + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + progress_bar_desc=None, + device=self.device, + ) + + log_ground_truth = ar_mask.min() == 0 + + if logger is not None: + for sp, st in zip(result[:10], input[:10]): + logger( + f"test_sequences {n_epoch} prediction {self.problem.seq2str(sp)}" + ) + if log_ground_truth: + logger( + f" {n_epoch} ground truth {self.problem.seq2str(st)}" + ) + + nb_total, nb_correct = self.problem.compute_nb_correct( + input, ar_mask, result + ) + + # nb_total = ar_mask.sum().item() + # nb_correct = ((result == input).long() * ar_mask).sum().item() + + return nb_total, nb_correct + + train_nb_total, train_nb_correct = compute_accuracy( + self.train_input, self.train_ar_mask + ) + + logger( + f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%" + ) + + test_nb_total, test_nb_correct = compute_accuracy( + self.test_input, self.test_ar_mask, logger + ) + + logger( + f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" + ) + + logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}") + + if save_attention_image is not None: + for k in range(10): + ns = torch.randint(self.test_input.size(0), (1,)).item() + input = self.test_input[ns : ns + 1].clone() + + with torch.autograd.no_grad(): + t = model.training + model.eval() + # model.record_attention(True) + model(BracketedSequence(input)) + model.train(t) + # ram = model.retrieve_attention() + # model.record_attention(False) + + # tokens_output = [c for c in self.problem.seq2str(input[0])] + # tokens_input = ["n/a"] + tokens_output[:-1] + # for n_head in range(ram[0].size(1)): + # filename = os.path.join( + # result_dir, f"sandbox_attention_{k}_h{n_head}.pdf" + # ) + # attention_matrices = [m[0, n_head] for m in ram] + # save_attention_image( + # filename, + # tokens_input, + # tokens_output, + # attention_matrices, + # k_top=10, + ##min_total_attention=0.9, + # token_gap=12, + # layer_gap=50, + # ) + # logger(f"wrote {filename}") + + +###################################################################### + +import picoclvr + + +class PicoCLVR(Task): + # Make a tensor from a list of strings + def tensorize(self, descr): + token_descr = [s.strip().split(" ") for s in descr] + l = max([len(s) for s in token_descr]) + token_descr = [s + [""] * (l - len(s)) for s in token_descr] + id_descr = [[self.token2id[u] for u in s] for s in token_descr] + return torch.tensor(id_descr, 
device=self.device) + + # Make a list of strings from a tensor + def detensorize(self, x): + return [" ".join([self.id2token[t.item()] for t in r]) for r in x] + + # trim all the tensors in the tuple z to remove as much token from + # left and right in the first tensor. If z is a tuple, all its + # elements are trimed according to the triming for the first + def trim(self, z, token=""): + n = self.token2id[token] + if type(z) == tuple: + x = z[0] + i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0) + a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min() + return tuple([t[:, a:b] for t in z]) + else: + i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0) + a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min() + return z[:, a:b] + + ###################### + + def __init__( + self, + nb_train_samples, + nb_test_samples, + batch_size, + height, + width, + nb_colors=5, + logger=None, + device=torch.device("cpu"), + pruner_train=None, + pruner_eval=None, + ): + super().__init__() + + def generate_descr(nb, cache_suffix, pruner): + return picoclvr.generate( + nb, + height=self.height, + width=self.width, + nb_colors=nb_colors, + pruner=pruner, + ) + + self.height = height + self.width = width + self.batch_size = batch_size + self.device = device + self.pruner_train = pruner_train + self.pruner_eval = pruner_eval + + if logger is not None: + logger( + f"generating {nb_train_samples+nb_test_samples} samples (can take some time)" + ) + + self.train_descr = generate_descr( + nb_train_samples, "train", pruner=self.pruner_train + ) + self.test_descr = generate_descr(nb_test_samples, "test", pruner=None) + + # Build the tokenizer + tokens = {"", ""} + for d in [self.train_descr, self.test_descr]: + for s in d: + for t in s.strip().split(" "): + tokens.add(t) + # make this set a sorted list to get the same tensors given + # the same descr + tokens = list(tokens) + tokens.sort() + self.token2id = dict([(t, n) for n, t in enumerate(tokens)]) + self.id2token = dict([(n, t) for n, t in enumerate(tokens)]) + self.t_img, self.t_nul = self.token2id[""], self.token2id[""] + + # Tokenize the train and test sets + self.train_input = self.tensorize(self.train_descr) + self.test_input = self.tensorize(self.test_descr) + + def batches(self, split="train"): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}" + ): + yield self.trim(batch) + + def vocabulary_size(self): + return len(self.token2id) + + def compute_missing_properties( + self, n_epoch, model, logger, deterministic_synthesis, pruner=None + ): + acc_nb_requested_properties = [] + acc_nb_missing_properties = [] + acc_nb_results = 0 + + for input in tqdm.tqdm( + self.test_input.split(self.batch_size), + dynamic_ncols=True, + desc=f"test-properties", + ): + result = input.clone() + ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1) + result = (1 - ar_mask) * result + ar_mask * self.t_nul + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + progress_bar_desc=None, + device=self.device, + ) + + result_descr = self.detensorize(result) + np = picoclvr.nb_properties( + result_descr, + height=self.height, + width=self.width, + pruner=pruner, + ) + nb_requested_properties, _, nb_missing_properties = zip(*np) + acc_nb_requested_properties += nb_requested_properties + 
acc_nb_missing_properties += nb_missing_properties + acc_nb_results += len(result_descr) + + nb_requested_properties = sum(acc_nb_requested_properties) + nb_missing_properties = sum(acc_nb_missing_properties) + + prefix = "" if pruner is None else "pruned_" + logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}") + logger( + f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}" + ) + logger( + f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%" + ) + + logger( + f"main_test_accuracy {n_epoch} {1-nb_missing_properties/nb_requested_properties}" + ) + + ###################################################################### + + def produce_results( + self, n_epoch, model, result_dir, logger, deterministic_synthesis + ): + self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis) + + if self.pruner_eval is not None: + self.compute_missing_properties(n_epoch, model, self.pruner_eval) + + nb_tokens_to_generate = self.height * self.width + 3 + result_descr = [] + nb_per_primer = 8 + primer = [] + + for primer_descr in [ + "red above green green top blue right of red", + "there is red there is yellow there is blue", + "red below yellow yellow below green green below blue red right yellow left green right blue left", + "green bottom yellow bottom green left of blue yellow right of blue blue top", + ]: + primer += [primer_descr + " "] * nb_per_primer + + result = self.tensorize(primer) + fill = result.new_full( + result.size()[:-1] + (self.height * self.width + 1,), self.t_nul + ) + result = torch.cat((result, fill), 1) + ar_mask = (result == self.t_nul).long() + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + result_descr = self.detensorize(result) + + np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width) + + acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np) + acc_nb_results = len(result_descr) + + nb_requested_properties = sum(acc_nb_requested_properties) + nb_missing_properties = sum(acc_nb_missing_properties) + + prefix = "demo_" + logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}") + logger( + f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}" + ) + logger( + f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%" + ) + + img = picoclvr.descr2img(result_descr, height=self.height, width=self.width) + + if img.dim() == 5: + if img.size(1) == 1: + img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64) + else: + img = torch.cat( + [ + torchvision.utils.make_grid(x, padding=1, pad_value=64)[None] + for x in img + ], + 0, + ) + + image_name = os.path.join(result_dir, f"picoclvr_result_{n_epoch:04d}.png") + torchvision.utils.save_image( + img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0 + ) + logger(f"wrote {image_name}") + + +###################################################################### + + +class MNIST(Task): + def __init__( + self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu") + ): + super().__init__() + + self.nb_train_samples = (nb_train_samples,) + self.nb_test_samples = (nb_test_samples,) + self.batch_size = batch_size + self.device = device + data_set = torchvision.datasets.MNIST(root="./data", train=True, download=True) + self.train_input = 
data_set.data[:nb_train_samples].view(-1, 28 * 28).long() + data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True) + self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long() + + def batches(self, split="train", nb_to_use=-1, desc=None): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + if nb_to_use > 0: + input = input[:nb_to_use] + if desc is None: + desc = f"epoch-{split}" + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=desc + ): + yield batch + + def vocabulary_size(self): + return 256 + + def produce_results( + self, n_epoch, model, result_dir, logger, deterministic_synthesis + ): + results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64) + ar_mask = torch.full_like(results, 1) + masked_inplace_autoregression( + model, + self.batch_size, + results, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png") + torchvision.utils.save_image( + 1 - results.reshape(-1, 1, 28, 28) / 255.0, + image_name, + nrow=16, + pad_value=0.8, + ) + logger(f"wrote {image_name}") + + +###################################################################### + +import maze + + +class Maze(Task): + def map2seq(self, *m): + return torch.cat([x.flatten(1) for x in m], 1) + + def seq2map(self, s): + s = s.reshape(s.size(0), -1, self.height, self.width) + return (s[:, k] for k in range(s.size(1))) + + def __init__( + self, + nb_train_samples, + nb_test_samples, + batch_size, + height, + width, + nb_walls, + device=torch.device("cpu"), + ): + super().__init__() + + self.batch_size = batch_size + self.height = height + self.width = width + self.device = device + + train_mazes, train_paths, _ = maze.create_maze_data( + nb_train_samples, + height=height, + width=width, + nb_walls=nb_walls, + progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"), + ) + self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device)) + + test_mazes, test_paths, _ = maze.create_maze_data( + nb_test_samples, + height=height, + width=width, + nb_walls=nb_walls, + progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"), + ) + self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device)) + + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + + def batches(self, split="train", nb_to_use=-1, desc=None): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + if nb_to_use > 0: + input = input[:nb_to_use] + if desc is None: + desc = f"epoch-{split}" + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=desc + ): + yield batch + + def vocabulary_size(self): + return self.nb_codes + + def compute_error( + self, model, split="train", nb_to_use=-1, deterministic_synthesis=False + ): + nb_total, nb_correct = 0, 0 + count = torch.zeros( + self.width * self.height, + self.width * self.height, + device=self.device, + dtype=torch.int64, + ) + + for input in self.batches(split, nb_to_use): + result = input.clone() + ar_mask = result.new_zeros(result.size()) + ar_mask[:, self.height * self.width :] = 1 + result *= 1 - ar_mask + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + progress_bar_desc=None, + device=self.device, + ) + mazes, paths = self.seq2map(result) + path_correctness = 
maze.path_correctness(mazes, paths) + nb_correct += path_correctness.long().sum() + nb_total += mazes.size(0) + + optimal_path_lengths = ( + (input[:, self.height * self.width :] == maze.v_path).long().sum(1) + ) + predicted_path_lengths = ( + (result[:, self.height * self.width :] == maze.v_path).long().sum(1) + ) + optimal_path_lengths = optimal_path_lengths[path_correctness] + predicted_path_lengths = predicted_path_lengths[path_correctness] + count[optimal_path_lengths, predicted_path_lengths] += 1 + + if count.max() == 0: + count = None + else: + count = count[ + : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1 + ] + + return nb_total, nb_correct, count + + def produce_results( + self, n_epoch, model, result_dir, logger, deterministic_synthesis + ): + train_nb_total, train_nb_correct, count = self.compute_error( + model, + "train", + nb_to_use=1000, + deterministic_synthesis=deterministic_synthesis, + ) + logger( + f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%" + ) + + test_nb_total, test_nb_correct, count = self.compute_error( + model, + "test", + nb_to_use=1000, + deterministic_synthesis=deterministic_synthesis, + ) + logger( + f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" + ) + + logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}") + + if count is not None: + proportion_optimal = count.diagonal().sum().float() / count.sum() + logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%") + with open( + os.path.join(result_dir, f"maze_result_{n_epoch:04d}.txt"), "w" + ) as f: + for i in range(count.size(0)): + for j in range(count.size(1)): + eol = " " if j < count.size(1) - 1 else "\n" + f.write(f"{count[i,j]}{eol}") + + input = self.test_input[:48] + result = input.clone() + ar_mask = result.new_zeros(result.size()) + ar_mask[:, self.height * self.width :] = 1 + result *= 1 - ar_mask + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + + mazes, paths = self.seq2map(input) + _, predicted_paths = self.seq2map(result) + + filename = os.path.join(result_dir, f"maze_result_{n_epoch:04d}.png") + maze.save_image( + filename, + mazes=mazes, + target_paths=paths, + predicted_paths=predicted_paths, + path_correct=maze.path_correctness(mazes, predicted_paths), + path_optimal=maze.path_optimality(paths, predicted_paths), + ) + logger(f"wrote {filename}") + + +###################################################################### + + +import snake + + +class Snake(Task): + def __init__( + self, + nb_train_samples, + nb_test_samples, + batch_size, + height, + width, + nb_colors, + length, + prompt_length, + device=torch.device("cpu"), + ): + super().__init__() + + self.batch_size = batch_size + self.height = height + self.width = width + self.device = device + self.prompt_length = prompt_length + + self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences( + nb_train_samples, + height, + width, + nb_colors, + length, + prompt_length, + self.device, + ) + self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences( + nb_test_samples, + height, + width, + nb_colors, + length, + prompt_length, + self.device, + ) + + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + + def batches(self, split="train", nb_to_use=-1, desc=None): + assert split 
in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + if nb_to_use > 0: + input = input[:nb_to_use] + if desc is None: + desc = f"epoch-{split}" + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=desc + ): + yield batch + + def vocabulary_size(self): + return self.nb_codes + + def produce_results( + self, n_epoch, model, result_dir, logger, deterministic_synthesis + ): + def compute_nb_correct(input, prior_visits): + result = input.clone() + i = torch.arange(result.size(1), device=result.device)[None, :] + ar_mask = ( + torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0) + .long() + .expand_as(result) + ) + result *= 1 - ar_mask + + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + + nb_total = ((prior_visits > 0) * ar_mask).sum() + + nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum() + + return nb_total, nb_correct + + test_nb_total, test_nb_correct = compute_nb_correct( + self.test_input[:1000], self.test_prior_visits[:1000] + ) + + logger( + f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" + ) + + logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}") + + +###################################################################### + + +import stack + + +class Stack(Task): + def __init__( + self, + nb_train_samples, + nb_test_samples, + batch_size, + logger, + nb_steps, + nb_stacks, + nb_digits, + fraction_values_for_train=None, + device=torch.device("cpu"), + ): + super().__init__() + + self.batch_size = batch_size + self.nb_steps = nb_steps + self.nb_stacks = nb_stacks + self.nb_digits = nb_digits + self.device = device + + if fraction_values_for_train is None: + values_for_train = None + values_for_test = None + else: + all = torch.randperm(10**nb_digits) + nb_for_train = int(all.size(0) * fraction_values_for_train) + values_for_train = all[:nb_for_train] + values_for_test = all[nb_for_train:] + + self.train_input, self.train_stack_counts = stack.generate_sequences( + nb_train_samples, + nb_steps, + nb_stacks, + nb_digits, + values_for_train, + self.device, + ) + + self.test_input, self.test_stack_counts = stack.generate_sequences( + nb_test_samples, + nb_steps, + nb_stacks, + nb_digits, + values_for_test, + self.device, + ) + + i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks) + counts = self.test_stack_counts.flatten()[i.flatten()] + counts = F.one_hot(counts).sum(0) + logger(f"test_pop_stack_counts {counts}") + + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + + def batches(self, split="train", nb_to_use=-1, desc=None): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + if nb_to_use > 0: + input = input[:nb_to_use] + if desc is None: + desc = f"epoch-{split}" + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=desc + ): + yield batch + + def vocabulary_size(self): + return self.nb_codes + + def produce_results( + self, n_epoch, model, result_dir, logger, deterministic_synthesis + ): + def compute_nb_correct(input): + result = input.clone() + stack.remove_popped_values(result, self.nb_stacks, self.nb_digits) + ar_mask = (result != input).long() + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + 
device=self.device, + ) + + errors = ((result != input).long() * ar_mask).reshape( + -1, 1 + self.nb_digits + ) + ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits) + + nb_total = ar_mask.max(1).values.sum() + nb_correct = nb_total - errors.max(1).values.sum() + + return nb_total, nb_correct + + test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000]) + + logger( + f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" + ) + + logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}") + + ############################################################## + # Log a few generated sequences + input = self.test_input[:10, : 12 * (1 + self.nb_digits)] + result = input.clone() + stack.remove_popped_values(result, self.nb_stacks, self.nb_digits) + ar_mask = (result != input).long() + + # for n in range(result.size(0)): + # logger( + # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}" + # ) + + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + + for n in range(result.size(0)): + logger( + f"test_after {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}" + ) + ############################################################## + + +###################################################################### + +import rpl + + +class RPL(Task): + def tensorize(self, sequences): + len_max = max([len(x) for x in sequences]) + return torch.cat( + [ + torch.tensor( + [ + [ + self.token2id[str(c)] + for c in s + [""] * (len_max - len(s)) + ] + for s in sequences + ] + ) + ], + 0, + ) + + def seq2str(self, seq): + return " ".join([self.id2token[i] for i in seq]) + + def __init__( + self, + nb_train_samples, + nb_test_samples, + batch_size, + nb_starting_values=3, + max_input=9, + prog_len=6, + nb_runs=5, + no_prog=False, + logger=None, + device=torch.device("cpu"), + ): + super().__init__() + + self.batch_size = batch_size + self.device = device + self.no_prog = no_prog + + train_sequences = [ + rpl.generate( + nb_starting_values=nb_starting_values, + nb_result_values_max=4 * nb_starting_values, + max_input=max_input, + prog_len=prog_len, + nb_runs=nb_runs, + ) + for _ in tqdm.tqdm(range(nb_train_samples), desc="train-data") + ] + + test_sequences = [ + rpl.generate( + nb_starting_values=nb_starting_values, + nb_result_values_max=4 * nb_starting_values, + max_input=max_input, + prog_len=prog_len, + nb_runs=nb_runs, + ) + for _ in tqdm.tqdm(range(nb_test_samples), desc="test-data") + ] + + symbols = list( + set([""] + [x for l in train_sequences + test_sequences for x in l]) + ) + val_max = max([x if type(x) is int else 0 for x in symbols]) + symbols = list(filter(lambda x: type(x) is str, symbols)) + symbols.sort() + symbols += [str(n) for n in range(val_max + 1)] + self.token2id = dict([(c, n) for n, c in enumerate(symbols)]) + self.id2token = dict([(n, c) for c, n in self.token2id.items()]) + + self.t_nul = self.token2id[""] + self.t_input = self.token2id[""] + self.t_output = self.token2id[""] + self.t_prog = self.token2id[""] + self.t_end = self.token2id[""] + + self.train_input = self.tensorize(train_sequences) + self.test_input = self.tensorize(test_sequences) + + if no_prog: + # Excise the program from every train and test example + k = torch.arange(self.train_input.size(1), device=self.train_input.device)[ + None, : + ] + p = ( + ((self.train_input 
== self.t_prog).long() * k) + .max(1, keepdim=True) + .values + ) + self.train_input = ( + self.train_input * (k <= p).long() + + self.t_end * (k == p + 1).long() + + self.t_nul * (k > p + 1).long() + ) + k = torch.arange(self.test_input.size(1), device=self.test_input.device)[ + None, : + ] + p = ( + ((self.test_input == self.t_prog).long() * k) + .max(1, keepdim=True) + .values + ) + self.test_input = ( + self.test_input * (k <= p).long() + + self.t_end * (k == p + 1).long() + + self.t_nul * (k > p + 1).long() + ) + + if logger is not None: + logger(f"value_max {val_max}") + for x in self.train_input[:25]: + end = (x != self.t_nul).nonzero().max().item() + 1 + seq = [self.id2token[i.item()] for i in x[:end]] + s = " ".join(seq) + logger(f"example_seq {s}") + + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + + def batches(self, split="train", nb_to_use=-1, desc=None): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + if nb_to_use > 0: + input = input[:nb_to_use] + if desc is None: + desc = f"epoch-{split}" + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=desc + ): + last = (batch != self.t_nul).max(0).values.nonzero().max() + 3 + batch = batch[:, :last].to(self.device) + yield batch + + def vocabulary_size(self): + return self.nb_codes + + def produce_results( + self, n_epoch, model, result_dir, logger, deterministic_synthesis + ): + # -------------------------------------------------------------------- + def compute_nb_errors_prog(input, nb_to_log=0): + result = input.clone() + s = (result == self.t_prog).long() + ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1) + result = (1 - ar_mask) * result + ar_mask * self.t_nul + + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + + sum_nb_total, sum_nb_errors = 0, 0 + for one_input, one_result in zip(input, result): + seq = [self.id2token[i.item()] for i in one_result] + nb_total, nb_errors, prog, stacks = rpl.compute_nb_errors(seq) + sum_nb_total += 1 + sum_nb_errors += 0 if nb_errors == 0 else 1 + if nb_to_log > 0: + gt_seq = [self.id2token[i.item()] for i in one_input] + _, _, gt_prog, _ = rpl.compute_nb_errors(gt_seq) + gt_prog = " ".join([str(x) for x in gt_prog]) + prog = " ".join([str(x) for x in prog]) + comment = "*" if nb_errors == 0 else "-" + logger(f"{comment} PROG [{gt_prog}] PREDICTED [{prog}]") + for start_stack, target_stack, result_stack, correct in stacks: + comment = "*" if correct else "-" + start_stack = " ".join([str(x) for x in start_stack]) + target_stack = " ".join([str(x) for x in target_stack]) + result_stack = " ".join([str(x) for x in result_stack]) + logger( + f" {comment} [{start_stack}] -> [{target_stack}] PREDICTED [{result_stack}]" + ) + nb_to_log -= 1 + + return sum_nb_total, sum_nb_errors + + # -------------------------------------------------------------------- + def compute_nb_errors_output(input, nb_to_log=0): + result = input.clone() + k = torch.arange(result.size(1), device=result.device)[None, :] + last_output_idx = ( + ((result == self.t_output) * k).max(dim=1, keepdim=True).values + ) + first_prog_idx = ( + ((result == self.t_prog) * k).max(dim=1, keepdim=True).values + ) + ar_mask = (k > last_output_idx).long() * (k < first_prog_idx).long() + result = (1 - ar_mask) * result + ar_mask * self.t_nul + + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + 
deterministic_synthesis, + device=self.device, + ) + + sum_nb_total, sum_nb_errors = 0, 0 + for one_input, one_result, i, j in zip( + input, result, last_output_idx, first_prog_idx + ): + seq = [self.id2token[i.item()] for i in one_result] + sum_nb_total += 1 + correct = (one_input - one_result).abs().max() == 0 + sum_nb_errors += 0 if correct else 1 + if nb_to_log > 0: + result_stack = [ + self.id2token[i.item()] for i in one_result[i : j + 1] + ] + target_stack = [ + self.id2token[i.item()] for i in one_input[i : j + 1] + ] + comment = "*" if correct else "-" + result_stack = " ".join([str(x) for x in result_stack]) + target_stack = " ".join([str(x) for x in target_stack]) + logger( + f"output_test {comment} [{target_stack}] PREDICTED [{result_stack}]" + ) + nb_to_log -= 1 + + return sum_nb_total, sum_nb_errors + + # -------------------------------------------------------------------- + + if not self.no_prog: + test_nb_total, test_nb_errors = compute_nb_errors_prog( + self.test_input[:1000].to(self.device), nb_to_log=10 + ) + + logger( + f"accuracy_prog_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%" + ) + + logger(f"main_test_accuracy {n_epoch} {1-test_nb_errors/test_nb_total}") + + test_nb_total, test_nb_errors = compute_nb_errors_output( + self.test_input[:1000].to(self.device), nb_to_log=10 + ) + + logger( + f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%" + ) + + if save_attention_image is not None: + ns = torch.randint(self.test_input.size(0), (1,)).item() + input = self.test_input[ns : ns + 1].clone() + last = (input != self.t_nul).max(0).values.nonzero().max() + 3 + input = input[:, :last].to(self.device) + + with torch.autograd.no_grad(): + t = model.training + model.eval() + model.record_attention(True) + model(BracketedSequence(input)) + model.train(t) + ram = model.retrieve_attention() + model.record_attention(False) + + tokens_output = [self.id2token[i.item()] for i in input[0]] + tokens_input = ["n/a"] + tokens_output[:-1] + for n_head in range(ram[0].size(1)): + filename = os.path.join( + result_dir, f"rpl_attention_{n_epoch}_h{n_head}.pdf" + ) + attention_matrices = [m[0, n_head] for m in ram] + save_attention_image( + filename, + tokens_input, + tokens_output, + attention_matrices, + k_top=10, + # min_total_attention=0.9, + token_gap=12, + layer_gap=50, + ) + logger(f"wrote {filename}") + + +###################################################################### + + +import expr + + +class Expr(Task): + def tensorize(self, sequences): + len_max = max([len(x) for x in sequences]) + return torch.cat( + [ + torch.tensor( + [ + [self.char2id[c] for c in s + "#" * (len_max - len(s))] + for s in sequences + ] + ) + ], + 0, + ).to(self.device) + + def __init__( + self, + nb_train_samples, + nb_test_samples, + nb_variables, + sequence_length, + operand_max, + result_max, + batch_size, + device=torch.device("cpu"), + ): + super().__init__() + + self.batch_size = batch_size + self.device = device + + train_sequences = expr.generate_sequences( + nb_train_samples, + nb_variables=nb_variables, + length=sequence_length, + operand_max=operand_max, + result_max=result_max, + ) + + test_sequences = expr.generate_sequences( + nb_test_samples, + nb_variables=nb_variables, + length=sequence_length, + operand_max=operand_max, + result_max=result_max, + ) + + symbols = list(set("#" + "".join(train_sequences + test_sequences))) 
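+        # The vocabulary is character-level: every character occurring
+        # in the train or test corpus becomes a token, and "#" (which
+        # cannot appear in a generated expression) is reserved as the
+        # filler used by tensorize() for right-padding.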
+ symbols.sort() + + self.char2id = dict([(c, n) for n, c in enumerate(symbols)]) + self.id2char = dict([(n, c) for c, n in self.char2id.items()]) + + self.filler, self.space = self.char2id["#"], self.char2id[" "] + + self.train_input = self.tensorize(train_sequences) + self.test_input = self.tensorize(test_sequences) + + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + + def batches(self, split="train", nb_to_use=-1, desc=None): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + if nb_to_use > 0: + input = input[:nb_to_use] + if desc is None: + desc = f"epoch-{split}" + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=desc + ): + last = (batch != self.filler).max(0).values.nonzero().max() + 3 + batch = batch[:, :last] + yield batch + + def vocabulary_size(self): + return self.nb_codes + + def seq2str(self, s): + return "".join([self.id2char[k.item()] for k in s]) + + def produce_results( + self, + n_epoch, + model, + result_dir, + logger, + deterministic_synthesis, + input_file=None, + ): + def compute_nb_correct(input): + result = input.clone() + s = (result == self.space).long() + ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1) + result = (1 - ar_mask) * result + ar_mask * self.filler + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + + nb_total = input.size(0) + nb_correct = (input == result).long().min(1).values.sum() + + ####################################################################### + # Comput predicted vs. true variable values + + nb_delta = torch.zeros(5, dtype=torch.int64) + nb_missed = 0 + + values_input = expr.extract_results([self.seq2str(s) for s in input]) + values_result = expr.extract_results([self.seq2str(s) for s in result]) + + filename = os.path.join(result_dir, f"expr_result_{n_epoch:04d}.txt") + + with open(filename, "w") as f: + for i, r in zip(values_input, values_result): + for n, vi in i.items(): + vr = r.get(n) + f.write(f"{vi} {-1 if vr is None else vr}\n") + + if vr is None or vr < 0: + nb_missed += 1 + else: + d = abs(vr - vi) + if d >= nb_delta.size(0): + nb_missed += 1 + else: + nb_delta[d] += 1 + + ###################################################################### + + return nb_total, nb_correct, nb_delta, nb_missed + + ( + test_nb_total, + test_nb_correct, + test_nb_delta, + test_nb_missed, + ) = compute_nb_correct(self.test_input[:10000]) + + logger( + f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" + ) + + logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}") + + nb_total = test_nb_delta.sum() + test_nb_missed + for d in range(test_nb_delta.size(0)): + logger( + f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%" + ) + logger( + f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%" + ) + + ############################################################## + # Log a few generated sequences + if input_file is None: + input = self.test_input[:10] + else: + with open(input_file, "r") as f: + sequences = [e.strip() for e in f.readlines()] + sequences = [s + " " + "#" * 50 for s in sequences] + input = self.tensorize(sequences) + + result = input.clone() + s = (result == self.space).long() + ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1) + result = (1 - ar_mask) * result + 
ar_mask * self.filler + + for n in range(result.size(0)): + logger(f"test_before {self.seq2str(result[n])}") + + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + + correct = (1 - ar_mask) * self.space + ar_mask * input + for n in range(result.size(0)): + comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else "" + logger(f"test_after {self.seq2str(result[n])} {comment}") + logger(f"truth {self.seq2str(correct[n])}") + ############################################################## + + +###################################################################### + +import grid + + +class Grid(Task): + # Make a tensor from a list of strings + def str2tensor(self, descr): + token_descr = [s.strip().split(" ") for s in descr] + l = max([len(s) for s in token_descr]) + token_descr = [s + ["#"] * (l - len(s)) for s in token_descr] + id_descr = [[self.token2id[u] for u in s] for s in token_descr] + return torch.tensor(id_descr, device=self.device) + + # Make a list of strings from a tensor + def tensor2str(self, x): + return [" ".join([self.id2token[t.item()] for t in r]) for r in x] + + # trim all the tensors in the tuple z to remove as much token from + # left and right in the first tensor. If z is a tuple, all its + # elements are trimed according to the triming for the first + def trim(self, z, token="#"): + n = self.token2id[token] + if type(z) == tuple: + x = z[0] + i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0) + a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min() + return tuple([t[:, a:b] for t in z]) + else: + i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0) + a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min() + return z[:, a:b] + + ###################### + + def __init__( + self, + nb_train_samples, + nb_test_samples, + batch_size, + size, + logger=None, + device=torch.device("cpu"), + ): + super().__init__() + + self.device = device + self.batch_size = batch_size + self.grid_factory = grid.GridFactory(size=size) + + if logger is not None: + logger( + f"generating {nb_train_samples+nb_test_samples} samples (can take some time)" + ) + + self.train_descr = self.grid_factory.generate_samples( + nb_train_samples, lambda r: tqdm.tqdm(r) + ) + self.test_descr = self.grid_factory.generate_samples( + nb_test_samples, lambda r: tqdm.tqdm(r) + ) + + # Build the tokenizer + tokens = set() + for d in [self.train_descr, self.test_descr]: + for s in d: + for t in s.strip().split(" "): + tokens.add(t) + # make this set a sorted list to get the same tensors given + # the same descr + tokens = list(tokens) + tokens.sort() + tokens = ["#"] + tokens + self.token2id = dict([(t, n) for n, t in enumerate(tokens)]) + self.id2token = dict([(n, t) for n, t in enumerate(tokens)]) + self.t_nul = self.token2id["#"] + self.t_true = self.token2id["true"] + self.t_false = self.token2id["false"] + + # Tokenize the train and test sets + self.train_input = self.str2tensor(self.train_descr) + self.test_input = self.str2tensor(self.test_descr) + + def batches(self, split="train"): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}" + ): + yield self.trim(batch) + + def vocabulary_size(self): + return len(self.token2id) + + def produce_results( + self, n_epoch, model, result_dir, logger, deterministic_synthesis 
+ ): + correct = self.test_input[:1000] + result = correct.clone() + ar_mask = torch.logical_or(result == self.t_true, result == self.t_false).long() + result *= 1 - ar_mask # paraaaaanoiaaaaaaa + + logger(f"----------------------------------------------------------") + + for e in self.tensor2str(result[:10]): + logger(f"test_before {e}") + + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + + logger(f"----------------------------------------------------------") + + for e in self.tensor2str(result[:10]): + logger(f"test_after {e}") + + logger(f"----------------------------------------------------------") + + nb_total = ar_mask.sum().item() + nb_correct = ((correct == result).long() * ar_mask).sum().item() + + logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}") + logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}") + + +###################################################################### + +import qmlp + + +class QMLP(Task): + ###################### + + def __init__( + self, + nb_train_samples, + nb_test_samples, + batch_size, + result_dir, + logger=None, + device=torch.device("cpu"), + ): + super().__init__() + + self.device = device + self.batch_size = batch_size + self.nb_samples_per_mlp = 256 + + if logger is not None: + logger( + f"generating {nb_train_samples+nb_test_samples} samples (can take some time)" + ) + + seq, q_test_set, test_error = qmlp.generate_sequence_and_test_set( + nb_mlps=nb_train_samples + nb_test_samples, + nb_samples=self.nb_samples_per_mlp, + device=self.device, + batch_size=64, + nb_epochs=250, + nb_mlps_per_batch=1024, + ) + + self.train_input = seq[:nb_train_samples] + self.train_q_test_set = q_test_set[:nb_train_samples] + self.train_ref_test_errors = test_error[:nb_train_samples] + self.test_input = seq[nb_train_samples:] + self.test_q_test_set = q_test_set[nb_train_samples:] + self.test_ref_test_errors = test_error[nb_train_samples:] + + filename = os.path.join(result_dir, f"train_errors_ref.dat") + with open(filename, "w") as f: + for e in self.train_ref_test_errors: + f.write(f"{e}\n") + + filename = os.path.join(result_dir, f"test_errors_ref.dat") + with open(filename, "w") as f: + for e in self.test_ref_test_errors: + f.write(f"{e}\n") + + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + + def batches(self, split="train"): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}" + ): + yield batch + + def vocabulary_size(self): + return self.nb_codes + + def produce_results( + self, n_epoch, model, result_dir, logger, deterministic_synthesis + ): + correct = self.test_input[:1000] + result = correct.clone() + ar_mask = ( + torch.arange(result.size(1), device=result.device) + > self.nb_samples_per_mlp * 3 + 1 + ).long()[None, :] + ar_mask = ar_mask.expand_as(result) + result *= 1 - ar_mask # paraaaaanoiaaaaaaa + + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + + q_train_set = result[:, : self.nb_samples_per_mlp * 3] + q_params = result[:, self.nb_samples_per_mlp * 3 + 1 :] + error_test = qmlp.evaluate_q_params(q_params, self.test_q_test_set) + + filename = os.path.join(result_dir, f"test_errors_{n_epoch:04d}.dat") + with open(filename, "w") as f: + for e in error_test: + f.write(f"{e}\n") + + 
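+# A minimal sketch (hypothetical driver, not part of this commit) of
+# how these Task objects are meant to be used: batches() feeds the
+# autoregressive model, and produce_results() is called once per epoch
+# for evaluation. The loss below assumes the mygpt convention of
+# predicting every position from the bracketed input; the actual
+# training script may shift or mask differently.
+def example_task_training_loop(task, model, optimizer, nb_epochs, result_dir, logger):
+    for n_epoch in range(nb_epochs):
+        model.train()
+        for input in task.batches(split="train"):
+            output = model(BracketedSequence(input)).x
+            loss = F.cross_entropy(output.transpose(1, 2), input)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+        model.eval()
+        task.produce_results(
+            n_epoch, model, result_dir, logger, deterministic_synthesis=False
+        )
+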
+###################################################################### diff --git a/world.py b/world.py new file mode 100755 index 0000000..aad0bfb --- /dev/null +++ b/world.py @@ -0,0 +1,485 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import math, sys, tqdm + +import torch, torchvision + +from torch import nn +from torch.nn import functional as F +import cairo + +###################################################################### + + +class Box: + nb_rgb_levels = 10 + + def __init__(self, x, y, w, h, r, g, b): + self.x = x + self.y = y + self.w = w + self.h = h + self.r = r + self.g = g + self.b = b + + def collision(self, scene): + for c in scene: + if ( + self is not c + and max(self.x, c.x) <= min(self.x + self.w, c.x + c.w) + and max(self.y, c.y) <= min(self.y + self.h, c.y + c.h) + ): + return True + return False + + +###################################################################### + + +class Normalizer(nn.Module): + def __init__(self, mu, std): + super().__init__() + self.register_buffer("mu", mu) + self.register_buffer("log_var", 2 * torch.log(std)) + + def forward(self, x): + return (x - self.mu) / torch.exp(self.log_var / 2.0) + + +class SignSTE(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + # torch.sign() takes three values + s = (x >= 0).float() * 2 - 1 + + if self.training: + u = torch.tanh(x) + return s + u - u.detach() + else: + return s + + +class DiscreteSampler2d(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + s = (x >= x.max(-3, keepdim=True).values).float() + + if self.training: + u = x.softmax(dim=-3) + return s + u - u.detach() + else: + return s + + +def loss_H(binary_logits, h_threshold=1): + p = binary_logits.sigmoid().mean(0) + h = (-p.xlogy(p) - (1 - p).xlogy(1 - p)) / math.log(2) + h.clamp_(max=h_threshold) + return h_threshold - h.mean() + + +def train_encoder( + train_input, + test_input, + depth, + nb_bits_per_token, + dim_hidden=48, + lambda_entropy=0.0, + lr_start=1e-3, + lr_end=1e-4, + nb_epochs=10, + batch_size=25, + logger=None, + device=torch.device("cpu"), +): + mu, std = train_input.float().mean(), train_input.float().std() + + def encoder_core(depth, dim): + l = [ + [ + nn.Conv2d( + dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2 + ), + nn.ReLU(), + nn.Conv2d(dim * 2**k, dim * 2 ** (k + 1), kernel_size=2, stride=2), + nn.ReLU(), + ] + for k in range(depth) + ] + + return nn.Sequential(*[x for m in l for x in m]) + + def decoder_core(depth, dim): + l = [ + [ + nn.ConvTranspose2d( + dim * 2 ** (k + 1), dim * 2**k, kernel_size=2, stride=2 + ), + nn.ReLU(), + nn.ConvTranspose2d( + dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2 + ), + nn.ReLU(), + ] + for k in range(depth - 1, -1, -1) + ] + + return nn.Sequential(*[x for m in l for x in m]) + + encoder = nn.Sequential( + Normalizer(mu, std), + nn.Conv2d(3, dim_hidden, kernel_size=1, stride=1), + nn.ReLU(), + # 64x64 + encoder_core(depth=depth, dim=dim_hidden), + # 8x8 + nn.Conv2d(dim_hidden * 2**depth, nb_bits_per_token, kernel_size=1, stride=1), + ) + + quantizer = SignSTE() + + decoder = nn.Sequential( + nn.Conv2d(nb_bits_per_token, dim_hidden * 2**depth, kernel_size=1, stride=1), + # 8x8 + decoder_core(depth=depth, dim=dim_hidden), + # 64x64 + nn.ConvTranspose2d(dim_hidden, 3 * Box.nb_rgb_levels, kernel_size=1, stride=1), + ) + + model = nn.Sequential(encoder, decoder) + + 
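+    # Shape bookkeeping: each of the `depth` stride-2 stages halves the
+    # spatial resolution, so a 3x64x64 frame becomes a
+    # nb_bits_per_token x (64 / 2**depth) x (64 / 2**depth) code map
+    # (8x8 for depth=3). SignSTE binarizes the codes with a
+    # straight-through tanh gradient, and the decoder mirrors the
+    # encoder to predict one of Box.nb_rgb_levels values per channel
+    # and pixel.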
nb_parameters = sum(p.numel() for p in model.parameters()) + + logger(f"vqae nb_parameters {nb_parameters}") + + model.to(device) + + for k in range(nb_epochs): + lr = math.exp( + math.log(lr_start) + math.log(lr_end / lr_start) / (nb_epochs - 1) * k + ) + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + + acc_train_loss = 0.0 + + for input in tqdm.tqdm(train_input.split(batch_size), desc="vqae-train"): + input = input.to(device) + z = encoder(input) + zq = quantizer(z) + output = decoder(zq) + + output = output.reshape( + output.size(0), -1, 3, output.size(2), output.size(3) + ) + + train_loss = F.cross_entropy(output, input) + + if lambda_entropy > 0: + train_loss = train_loss + lambda_entropy * loss_H(z, h_threshold=0.5) + + acc_train_loss += train_loss.item() * input.size(0) + + optimizer.zero_grad() + train_loss.backward() + optimizer.step() + + acc_test_loss = 0.0 + + for input in tqdm.tqdm(test_input.split(batch_size), desc="vqae-test"): + input = input.to(device) + z = encoder(input) + zq = quantizer(z) + output = decoder(zq) + + output = output.reshape( + output.size(0), -1, 3, output.size(2), output.size(3) + ) + + test_loss = F.cross_entropy(output, input) + + acc_test_loss += test_loss.item() * input.size(0) + + train_loss = acc_train_loss / train_input.size(0) + test_loss = acc_test_loss / test_input.size(0) + + logger(f"vqae train {k} lr {lr} train_loss {train_loss} test_loss {test_loss}") + sys.stdout.flush() + + return encoder, quantizer, decoder + + +###################################################################### + + +def scene2tensor(xh, yh, scene, size): + width, height = size, size + pixel_map = torch.ByteTensor(width, height, 4).fill_(255) + data = pixel_map.numpy() + surface = cairo.ImageSurface.create_for_data( + data, cairo.FORMAT_ARGB32, width, height + ) + + ctx = cairo.Context(surface) + ctx.set_fill_rule(cairo.FILL_RULE_EVEN_ODD) + + for b in scene: + ctx.move_to(b.x * size, b.y * size) + ctx.rel_line_to(b.w * size, 0) + ctx.rel_line_to(0, b.h * size) + ctx.rel_line_to(-b.w * size, 0) + ctx.close_path() + ctx.set_source_rgba( + b.r / (Box.nb_rgb_levels - 1), + b.g / (Box.nb_rgb_levels - 1), + b.b / (Box.nb_rgb_levels - 1), + 1.0, + ) + ctx.fill() + + hs = size * 0.1 + ctx.set_source_rgba(0.0, 0.0, 0.0, 1.0) + ctx.move_to(xh * size - hs / 2, yh * size - hs / 2) + ctx.rel_line_to(hs, 0) + ctx.rel_line_to(0, hs) + ctx.rel_line_to(-hs, 0) + ctx.close_path() + ctx.fill() + + return ( + pixel_map[None, :, :, :3] + .flip(-1) + .permute(0, 3, 1, 2) + .long() + .mul(Box.nb_rgb_levels) + .floor_divide(256) + ) + + +def random_scene(nb_insert_attempts=3): + scene = [] + colors = [ + ((Box.nb_rgb_levels - 1), 0, 0), + (0, (Box.nb_rgb_levels - 1), 0), + (0, 0, (Box.nb_rgb_levels - 1)), + ((Box.nb_rgb_levels - 1), (Box.nb_rgb_levels - 1), 0), + ( + (Box.nb_rgb_levels * 2) // 3, + (Box.nb_rgb_levels * 2) // 3, + (Box.nb_rgb_levels * 2) // 3, + ), + ] + + for k in range(nb_insert_attempts): + wh = torch.rand(2) * 0.2 + 0.2 + xy = torch.rand(2) * (1 - wh) + c = colors[torch.randint(len(colors), (1,))] + b = Box( + xy[0].item(), xy[1].item(), wh[0].item(), wh[1].item(), c[0], c[1], c[2] + ) + if not b.collision(scene): + scene.append(b) + + return scene + + +def generate_episode(steps, size=64): + delta = 0.1 + effects = [ + (False, 0, 0), + (False, delta, 0), + (False, 0, delta), + (False, -delta, 0), + (False, 0, -delta), + (True, delta, 0), + (True, 0, delta), + (True, -delta, 0), + (True, 0, -delta), + ] + + while True: + frames = [] + + scene = random_scene() 
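+        # Roll out one episode: at each step an action either moves the
+        # hand by +-delta or, when its "grasp" flag is set and the hand
+        # is inside a box, drags that box along; moves that would leave
+        # the unit square or cause a collision are undone. Episodes in
+        # which boxes moved on fewer than a third of the steps are
+        # rejected and re-sampled.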
+def generate_episode(steps, size=64):
+    delta = 0.1
+    # each effect: (grasp?, dx, dy)
+    effects = [
+        (False, 0, 0),
+        (False, delta, 0),
+        (False, 0, delta),
+        (False, -delta, 0),
+        (False, 0, -delta),
+        (True, delta, 0),
+        (True, 0, delta),
+        (True, -delta, 0),
+        (True, 0, -delta),
+    ]
+
+    while True:
+        frames = []
+
+        scene = random_scene()
+        xh, yh = tuple(x.item() for x in torch.rand(2))
+
+        actions = torch.randint(len(effects), (len(steps),))
+        nb_changes = 0
+
+        for s, a in zip(steps, actions):
+            if s:
+                frames.append(scene2tensor(xh, yh, scene, size=size))
+
+            grasp, dx, dy = effects[a]
+
+            if grasp:
+                # drag the box under the hand, and cancel the move if it
+                # leaves the unit square or collides with another box
+                for b in scene:
+                    if b.x <= xh and b.x + b.w >= xh and b.y <= yh and b.y + b.h >= yh:
+                        x, y = b.x, b.y
+                        b.x += dx
+                        b.y += dy
+                        if (
+                            b.x < 0
+                            or b.y < 0
+                            or b.x + b.w > 1
+                            or b.y + b.h > 1
+                            or b.collision(scene)
+                        ):
+                            b.x, b.y = x, y
+                        else:
+                            xh += dx
+                            yh += dy
+                            nb_changes += 1
+            else:
+                x, y = xh, yh
+                xh += dx
+                yh += dy
+                if xh < 0 or xh > 1 or yh < 0 or yh > 1:
+                    xh, yh = x, y
+
+        # regenerate the episode until at least a third of the actions
+        # actually moved a box
+        if nb_changes > len(steps) // 3:
+            break
+
+    return frames, actions
+
+
+######################################################################
+
+
+def generate_episodes(nb, steps):
+    all_frames, all_actions = [], []
+    for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world-data"):
+        frames, actions = generate_episode(steps)
+        all_frames += frames
+        all_actions += [actions[None, :]]
+    return torch.cat(all_frames, 0).contiguous(), torch.cat(all_actions, 0)
+
+
+def create_data_and_processors(
+    nb_train_samples,
+    nb_test_samples,
+    mode,
+    nb_steps,
+    depth=3,
+    nb_bits_per_token=8,
+    nb_epochs=10,
+    device=torch.device("cpu"),
+    device_storage=torch.device("cpu"),
+    logger=None,
+):
+    assert mode in ["first_last"]
+
+    if mode == "first_last":
+        # keep only the first and the last frames of each episode
+        steps = [True] + [False] * (nb_steps + 1) + [True]
+
+    if logger is None:
+        logger = lambda s: print(s)
+
+    train_input, train_actions = generate_episodes(nb_train_samples, steps)
+    train_input, train_actions = train_input.to(device_storage), train_actions.to(
+        device_storage
+    )
+    test_input, test_actions = generate_episodes(nb_test_samples, steps)
+    test_input, test_actions = test_input.to(device_storage), test_actions.to(
+        device_storage
+    )
+
+    encoder, quantizer, decoder = train_encoder(
+        train_input,
+        test_input,
+        depth=depth,
+        nb_bits_per_token=nb_bits_per_token,
+        lambda_entropy=1.0,
+        nb_epochs=nb_epochs,
+        logger=logger,
+        device=device,
+    )
+    encoder.train(False)
+    quantizer.train(False)
+    decoder.train(False)
+
+    z = encoder(train_input[:1].to(device))
+    pow2 = (2 ** torch.arange(z.size(1), device=device))[None, None, :]
+    z_h, z_w = z.size(2), z.size(3)
+
+    logger(f"vqae input {train_input[0].size()} output {z[0].size()}")
+
+    def frame2seq(input, batch_size=25):
+        # pack the bit planes of the quantized latent into one integer
+        # token per spatial location
+        seq = []
+        p = pow2.to(device)
+        for x in input.split(batch_size):
+            x = x.to(device)
+            z = encoder(x)
+            ze_bool = (quantizer(z) >= 0).long()
+            output = (
+                ze_bool.permute(0, 2, 3, 1).reshape(
+                    ze_bool.size(0), -1, ze_bool.size(1)
+                )
+                * p
+            ).sum(-1)
+
+            seq.append(output)
+
+        return torch.cat(seq, dim=0)
+
+    def seq2frame(input, batch_size=25, T=1e-2):
+        # unpack the token bits, map {0, 1} to {-1, +1}, decode, and
+        # sample the pixel levels at temperature T
+        frames = []
+        p = pow2.to(device)
+        for seq in input.split(batch_size):
+            seq = seq.to(device)
+            zd_bool = (seq[:, :, None] // p) % 2
+            zd_bool = zd_bool.reshape(zd_bool.size(0), z_h, z_w, -1).permute(0, 3, 1, 2)
+            logits = decoder(zd_bool * 2.0 - 1.0)
+            logits = logits.reshape(
+                logits.size(0), -1, 3, logits.size(2), logits.size(3)
+            ).permute(0, 2, 3, 4, 1)
+            output = torch.distributions.categorical.Categorical(
+                logits=logits / T
+            ).sample()
+
+            frames.append(output)
+
+        return torch.cat(frames, dim=0)
+
+    return train_input, train_actions, test_input, test_actions, frame2seq, seq2frame
+
+
+######################################################################
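The frame2seq / seq2frame pair defined above maps frames to token sequences by thresholding the quantizer output to bits and packing the bits of each latent location into a single integer with powers of two, then reverses the packing before decoding. A self-contained sketch of just that round trip on random bits; the tensor shapes and the nb_bits value are illustrative, not the module's defaults:

import torch

nb_bits, h, w = 8, 4, 4
pow2 = (2 ** torch.arange(nb_bits))[None, None, :]

# a stand-in for the binarized latent (quantizer(z) >= 0), shape (N, nb_bits, h, w)
bits = torch.randint(2, (5, nb_bits, h, w))

# pack, as in frame2seq: one token in [0, 2 ** nb_bits) per spatial location
tokens = (bits.permute(0, 2, 3, 1).reshape(bits.size(0), -1, nb_bits) * pow2).sum(-1)

# unpack, as in seq2frame: recover the bit planes from the tokens
unpacked = (
    ((tokens[:, :, None] // pow2) % 2)
    .reshape(bits.size(0), h, w, nb_bits)
    .permute(0, 3, 1, 2)
)

assert (unpacked == bits).all()
print(tokens.size())  # torch.Size([5, 16]), i.e. h * w tokens per frame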
+
+if __name__ == "__main__":
+    (
+        train_input,
+        train_actions,
+        test_input,
+        test_actions,
+        frame2seq,
+        seq2frame,
+    ) = create_data_and_processors(
+        25000,
+        1000,
+        nb_epochs=5,
+        mode="first_last",
+        nb_steps=20,
+    )
+
+    input = test_input[:256]
+
+    seq = frame2seq(input)
+    output = seq2frame(seq)
+
+    torchvision.utils.save_image(
+        input.float() / (Box.nb_rgb_levels - 1), "orig.png", nrow=16
+    )
+
+    torchvision.utils.save_image(
+        output.float() / (Box.nb_rgb_levels - 1), "qtiz.png", nrow=16
+    )
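Assuming the input and output variables of the demo above, a quick sanity check of the frame2seq / seq2frame round trip could report the fraction of pixel levels reproduced exactly; this snippet is an editorial sketch, not part of the commit:

# both tensors have shape (256, 3, 64, 64) and hold integer pixel levels
agreement = (output == input).float().mean().item()
print(f"exact pixel-level agreement {agreement:.4f}")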