#!/usr/bin/env python

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

import math

import torch

from torch import nn
from torch.nn import functional as F

# from torch.nn.attention.flex_attention import flex_attention, create_block_mask

######################################################################


class VaswaniPositionalEncoding(nn.Module):
    """Add a fixed sinusoidal positional encoding to a (N, T, D) tensor.

    For position ``t`` (dim 1) and channel ``j`` (dim 2) the added value is
    ``sin(t / len_max ** ((j - j % 2) / D) + (pi / 2) * (j % 2))``: a sine on
    even channels and the same sine shifted by pi/2 on odd channels.
    """

    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    def forward(self, x):
        """Return ``x`` plus the encoding, broadcast over the batch dim."""
        t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None]
        j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :]
        k = j % 2  # works with float, weird
        pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k)
        y = x + pe
        return y


######################################################################


class WithResidual(nn.Module):
    """Wrap one module, or a sequence of modules, as ``x -> x + f(x)``."""

    def __init__(self, *f):
        super().__init__()
        # A single module is used as is, several are chained in order.
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, x):
        return x + self.f(x)


######################################################################


def vanilla_attention(q, k, v):
    """Scaled dot-product attention without any masking.

    Args:
        q: queries, indexed (n, h, t, d).
        k: keys, indexed (n, h, s, d).
        v: values, indexed (n, h, s, d_v).

    Returns:
        Tensor indexed (n, h, t, d_v): for every query, the softmax-weighted
        (over ``s``) average of the values.
    """
    a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
    a = a.softmax(dim=3)
    y = torch.einsum("nhts,nhsd->nhtd", a, v)
    return y


######################################################################


class MHAttention(nn.Module):
    """Multi-head attention layer with a pluggable attention function.

    Args:
        dim_model: channel dimension of the inputs and of the output.
        dim_qk: per-head dimension of the queries and keys.
        dim_v: per-head dimension of the values.
        nb_heads: number of heads.
        attention: callable ``(q, k, v) -> y`` working on per-head tensors.
        attention_dropout: stored on the module only.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads=1,
        attention=vanilla_attention,
        attention_dropout=0.0,
    ):
        super().__init__()

        def randw(*d):
            # Gaussian weights scaled by 1 / sqrt(size of the last dim).
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.attention = attention
        # NOTE(review): this value is stored but nothing in this class
        # applies it, so no dropout happens on the attention weights.
        self.attention_dropout = attention_dropout
        self.w_q = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(nb_heads, dim_v, dim_model)

    def forward(self, x_q, x_kv=None):
        """Attend from ``x_q`` to ``x_kv`` (self-attention when it is None).

        Both inputs are indexed (n, t, c); the result is indexed like
        ``x_q``, the contributions of all the heads being summed.
        """
        if x_kv is None:
            x_kv = x_q

        q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q)
        k = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_k)
        v = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_v)
        y = self.attention(q, k, v)
        y = torch.einsum("nhtd,hdc->ntc", y, self.w_o)

        return y


######################################################################


def _init_embedding_and_layernorm(model):
    """Re-initialize in place the embeddings and layer norms of ``model``.

    Embedding weights are drawn from N(0, 0.02^2); layer norms get a zero
    bias and a unit weight.
    """
    with torch.no_grad():
        for m in model.modules():
            if isinstance(m, nn.Embedding):
                m.weight.normal_(mean=0, std=2e-2)
            elif isinstance(m, nn.LayerNorm):
                m.bias.zero_()
                m.weight.fill_(1.0)


######################################################################


class AttentionAE(nn.Module):
    """Stack of pre-layer-norm attention / MLP residual blocks.

    The input is a tensor of token indices indexed (n, t). The embedding
    table has ``2 * vocabulary_size`` rows, and the output holds
    ``vocabulary_size`` logits per position, indexed (n, t, vocabulary_size).

    Args:
        vocabulary_size: number of output classes per position.
        dim_model: channel dimension of the trunk, a multiple of ``nb_heads``.
        dim_keys: per-head dimension of the queries and keys.
        dim_hidden: hidden dimension of the per-position MLPs.
        nb_heads: number of attention heads.
        nb_blocks: number of (attention, MLP) pairs.
        dropout: dropout probability after the embedding and the MLPs.
        len_max: base of the positional encoding wavelengths.
    """

    def __init__(
        self,
        vocabulary_size,
        dim_model,
        dim_keys,
        dim_hidden,
        nb_heads,
        nb_blocks,
        dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        assert dim_model % nb_heads == 0

        self.embedding = nn.Sequential(
            nn.Embedding(2 * vocabulary_size, dim_model),
            nn.Dropout(dropout),
        )

        self.positional_encoding = VaswaniPositionalEncoding(len_max)

        trunk_blocks = []

        for _ in range(nb_blocks):
            trunk_blocks += [
                WithResidual(
                    nn.LayerNorm((dim_model,)),
                    MHAttention(
                        dim_model=dim_model,
                        dim_qk=dim_keys,
                        dim_v=dim_model // nb_heads,
                        nb_heads=nb_heads,
                        attention=vanilla_attention,
                        attention_dropout=dropout,
                    ),
                ),
                WithResidual(
                    nn.LayerNorm((dim_model,)),
                    nn.Linear(in_features=dim_model, out_features=dim_hidden),
                    nn.ReLU(),
                    nn.Linear(in_features=dim_hidden, out_features=dim_model),
                    nn.Dropout(dropout),
                ),
            ]

        self.trunk = nn.Sequential(*trunk_blocks)

        self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size)

        _init_embedding_and_layernorm(self)

    def forward(self, x):
        x = self.embedding(x)
        x = self.positional_encoding(x)
        x = self.trunk(x)
        x = self.readout(x)
        return x


######################################################################


class WithMaskedResidual(nn.Module):
    """Residual wrapper whose skip connection is gated: ``mask * x + f(x)``.

    Args:
        masker: callable ``x -> mask``, the mask being broadcastable to
            ``x``. Where the mask is zero the output is ``f(x)`` alone.
        *f: one module, or several that are chained in order.
    """

    def __init__(self, masker, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
        self.masker = masker
        self.mask = None

    def forward(self, x):
        # The mask is recomputed for every input. Keeping the one from the
        # first call breaks as soon as the sequence length or the device of
        # x changes between calls. The last mask stays available as
        # self.mask.
        self.mask = self.masker(x)
        return self.mask * x + self.f(x)


######################################################################


class FunctionalAttentionAE(nn.Module):
    """Variant of the attention stack with extra leading "work" positions.

    ``nb_work_tokens`` zero vectors are prepended to the embedded sequence
    before the positional encoding and removed before the readout. In the
    trunk these positions get no residual connection, and the attention
    forbids the two halves of the remaining sequence to look at each other.

    Args:
        vocabulary_size: number of output classes per position.
        dim_model: channel dimension of the trunk, a multiple of ``nb_heads``.
        dim_keys: per-head dimension of the queries and keys.
        dim_hidden: hidden dimension of the per-position MLPs.
        nb_heads: number of attention heads.
        nb_blocks: number of (attention, MLP) pairs.
        nb_work_tokens: number of positions prepended to the sequence.
        dropout: dropout probability after the embedding and the MLPs.
        len_max: base of the positional encoding wavelengths.
    """

    def __init__(
        self,
        vocabulary_size,
        dim_model,
        dim_keys,
        dim_hidden,
        nb_heads,
        nb_blocks,
        nb_work_tokens=100,
        dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        assert dim_model % nb_heads == 0

        self.nb_work_tokens = nb_work_tokens

        self.embedding = nn.Sequential(
            nn.Embedding(2 * vocabulary_size, dim_model),
            nn.Dropout(dropout),
        )

        self.positional_encoding = VaswaniPositionalEncoding(len_max)

        trunk_blocks = []

        def no_peek_attention(q, k, v):
            # Same as vanilla_attention, except that with n work tokens and
            # s = (T - n) // 2, the positions [n, n + s) and [n + s, n + 2s)
            # cannot attend to each other. The work tokens, and a trailing
            # position when T - n is odd, are left unconstrained.
            a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
            n = self.nb_work_tokens
            s = (q.size(2) - n) // 2
            a[:, :, n + 1 * s : n + 2 * s, n + 0 * s : n + 1 * s] = float("-inf")
            a[:, :, n + 0 * s : n + 1 * s, n + 1 * s : n + 2 * s] = float("-inf")
            a = a.softmax(dim=3)
            y = torch.einsum("nhts,nhsd->nhtd", a, v)
            return y

        def masker(x):
            # Boolean mask of shape (1, T, 1), False on the work tokens.
            m = torch.arange(x.size(1), device=x.device) >= self.nb_work_tokens
            return m[None, :, None]

        for _ in range(nb_blocks):
            trunk_blocks += [
                WithMaskedResidual(
                    masker,
                    nn.LayerNorm((dim_model,)),
                    MHAttention(
                        dim_model=dim_model,
                        dim_qk=dim_keys,
                        dim_v=dim_model // nb_heads,
                        nb_heads=nb_heads,
                        attention=no_peek_attention,
                        attention_dropout=dropout,
                    ),
                ),
                WithMaskedResidual(
                    masker,
                    nn.LayerNorm((dim_model,)),
                    nn.Linear(in_features=dim_model, out_features=dim_hidden),
                    nn.ReLU(),
                    nn.Linear(in_features=dim_hidden, out_features=dim_model),
                    nn.Dropout(dropout),
                ),
            ]

        self.trunk = nn.Sequential(*trunk_blocks)

        self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size)

        _init_embedding_and_layernorm(self)

    def forward(self, x):
        x = self.embedding(x)
        # Prepend nb_work_tokens all-zero vectors along the sequence dim.
        x = F.pad(x, (0, 0, self.nb_work_tokens, 0))
        x = self.positional_encoding(x)
        x = self.trunk(x)
        # Drop the work tokens so the output aligns with the input tokens.
        x = x[:, self.nb_work_tokens :]
        x = self.readout(x)
        return x


######################################################################


if __name__ == "__main__":
    model = FunctionalAttentionAE(
        vocabulary_size=100,
        dim_model=16,
        dim_keys=64,
        dim_hidden=32,
        nb_heads=4,
        nb_work_tokens=10,
        nb_blocks=4,
        dropout=0.1,
    )

    x = torch.randint(100, (10, 50))
    y = model(x)

    with torch.no_grad():
        model.eval()
        x = torch.randint(100, (10, 50))
        y = model(x)

    print(y)
is None: + y = height * 1.5 + x = height * 0.5 + + ctx.move_to(x, y) + ctx.show_text(line) + y += height * 1.5 + + ctx.stroke() + + return pixel_map.permute(2, 0, 1)[None, :3].contiguous() + + +###################################################################### + import problem +def grow_islands(nb, height, width, nb_seeds, nb_iterations): + w = torch.empty(5, 1, 3, 3) + + w[0, 0] = torch.tensor( + [ + [1.0, 1.0, 1.0], + [1.0, 0.0, 1.0], + [1.0, 1.0, 1.0], + ] + ) + + w[1, 0] = torch.tensor( + [ + [-1.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ] + ) + + w[2, 0] = torch.tensor( + [ + [0.0, 1.0, -1.0], + [0.0, 0.0, 1.0], + [0.0, 0.0, 0.0], + ] + ) + + w[3, 0] = torch.tensor( + [ + [0.0, 0.0, 0.0], + [0.0, 0.0, 1.0], + [0.0, 1.0, -1.0], + ] + ) + + w[4, 0] = torch.tensor( + [ + [0.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [-1.0, 1.0, 0.0], + ] + ) + + Z = torch.zeros(nb, height, width) + U = Z.flatten(1) + + for _ in range(nb_seeds): + M = F.conv2d(Z[:, None, :, :], w, padding=1) + M = torch.cat([M[:, :1], M[:, 1:].min(dim=1, keepdim=True).values], dim=1) + M = ((M[:, 0] == 0) & (Z == 0)).long() + Q = (M.flatten(1).max(dim=1).values > 0).long()[:, None] + M = M * torch.rand(M.size()) + M = M.flatten(1) + M = F.one_hot(M.argmax(dim=1), num_classes=M.size(1)) + U += M * Q + + for _ in range(nb_iterations): + M = F.conv2d(Z[:, None, :, :], w, padding=1) + M = torch.cat([M[:, :1], M[:, 1:].min(dim=1, keepdim=True).values], dim=1) + M = ((M[:, 1] >= 0) & (Z == 0)).long() + Q = (M.flatten(1).max(dim=1).values > 0).long()[:, None] + M = M * torch.rand(M.size()) + M = M.flatten(1) + M = F.one_hot(M.argmax(dim=1), num_classes=M.size(1)) + U = Z.flatten(1) + U += M * Q + + M = Z.clone() + Z = Z * (torch.arange(Z.size(1) * Z.size(2)) + 1).reshape(1, Z.size(1), Z.size(2)) + + while True: + W = Z.clone() + Z = F.max_pool2d(Z, 3, 1, 1) * M + if Z.equal(W): + break + + Z = Z.long() + U = Z.flatten(1) + V = F.one_hot(U).max(dim=1).values + W = V.cumsum(dim=1) - V + N = 
torch.arange(Z.size(0))[:, None, None].expand_as(Z) + Z = W[N, Z] + + return Z + + class Grids(problem.Problem): + # grid_gray = 64 + # thickness = 1 + # background_gray = 255 + # dots = False + + grid_gray = 240 + thickness = 0 + background_gray = 240 + dots = False + + # grid_gray = 192 + # thickness = 0 + # background_gray = 255 + # dots = True + named_colors = [ - ("white", [255, 255, 255]), + ("white", [background_gray, background_gray, background_gray]), + # ("white", [224, 224, 224]), ("red", [255, 0, 0]), - ("green", [0, 192, 0]), + ("green", [0, 160, 0]), ("blue", [0, 0, 255]), ("yellow", [255, 224, 0]), ("cyan", [0, 255, 255]), ("violet", [224, 128, 255]), - ("lightgreen", [192, 255, 192]), + ("lightgreen", [160, 255, 160]), ("brown", [165, 42, 42]), ("lightblue", [192, 192, 255]), ("gray", [128, 128, 128]), ] - def __init__(self, device=torch.device("cpu")): + def pure_noise(self, nb, device): + result = torch.randint( + self.nb_colors, (nb, 4 * (self.height * self.height)), device=device + ) + return result + + def trivial(self, quizzes): + S = self.height * self.width + assert self.check_order(quizzes, quad_order=("A", "f_A", "B", "f_B")) + a = quizzes.reshape(quizzes.size(0), 4, S + 1)[:, :, 1:] + return (a[:, 0] == a[:, 1]).min(dim=1).values | (a[:, 2] == a[:, 3]).min( + dim=1 + ).values + + def text2quiz(self, t): + chr2col = [ + (".", "white"), + ("r", "red"), + ("g", "green"), + ("b", "blue"), + ("y", "yellow"), + ("c", "cyan"), + ("v", "violet"), + ("l", "lightgreen"), + ("o", "brown"), + ("l", "lightblue"), + ("a", "gray"), + ] + + col2tok = dict([(c[0], n) for n, c in enumerate(self.named_colors)]) + chr2tok = dict([(c, col2tok[col]) for c, col in chr2col]) + + t = re.sub(r"#.*\n", "", t).strip() + l = t.replace("\n\n", ";").split(";") + + result = [] + + for t in l: + t = "".join(t.replace("\n", " ").strip().split(" ")) + t = torch.tensor([chr2tok[c] for c in t]) + t = t.reshape(10, 4, 10).permute(1, 0, 2).flatten(1) + t = torch.cat( + [ + 
torch.tensor( + [ + [self.token_A], + [self.token_f_A], + [self.token_B], + [self.token_f_B], + ] + ), + t, + ], + dim=1, + ) + result.append(t.flatten()[None, :]) + + return torch.cat(result, dim=0) + + def indices_select(self, quizzes, quad_order=("A", "f_A", "B", "f_B")): + S = self.height * self.width + q = quizzes.reshape(quizzes.size(0), 4, S + 1) + return ( + (q[:, 0, 0] == self.l2tok[quad_order[0]]) + & (q[:, 1, 0] == self.l2tok[quad_order[1]]) + & (q[:, 2, 0] == self.l2tok[quad_order[2]]) + & (q[:, 3, 0] == self.l2tok[quad_order[3]]) + ) + + def __init__( + self, + max_nb_cached_chunks=None, + chunk_size=None, + nb_threads=-1, + tasks=None, + ): self.colors = torch.tensor([c for _, c in self.named_colors]) + + self.nb_colors = len(self.colors) + + self.nb_rec_max = 5 + self.rfree = torch.tensor([]) + self.height = 10 self.width = 10 - self.device = device + self.seq_len = 4 * self.height * self.width - ###################################################################### + self.cache_rec_coo = {} + + all_tasks = [ + ############################################ fundamental ones + self.task_replace_color, + self.task_translate, + self.task_grow, + self.task_frame, + ############################################ + ############################################ + self.task_half_fill, + self.task_detect, + self.task_scale, + self.task_symbols, + self.task_corners, + self.task_contact, + self.task_path, + self.task_fill, + ############################################ hard ones + self.task_isometry, + self.task_trajectory, + self.task_bounce, + # self.task_count, # NOT REVERSIBLE + # self.task_islands, # TOO MESSY + ] + + if tasks is None: + self.all_tasks = all_tasks + else: + self.all_tasks = [getattr(self, "task_" + t) for t in tasks.split(",")] - def frame2img(self, x, scale=15): - x = x.reshape(x.size(0), self.height, -1) - m = torch.logical_and(x >= 0, x < self.nb_token_values()).long() - x = self.colors[x * m].permute(0, 3, 1, 2) - s = x.shape - x = x[:, :, 
:, None, :, None].expand(-1, -1, -1, scale, -1, scale) - x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale) + super().__init__(max_nb_cached_chunks, chunk_size, nb_threads) - x[:, :, :, torch.arange(0, x.size(3), scale)] = 0 - x[:, :, torch.arange(0, x.size(2), scale), :] = 0 - x = x[:, :, 1:, 1:] + ###################################################################### + + def vocabulary_size(self): + # warnings.warn("hack +4 to keep the vocabulary size unchanged", RuntimeWarning) + # return self.nb_colors+4 + return self.nb_colors + + def grid2img(self, x, scale=15, grids=True): + m = torch.logical_and(x >= 0, x < self.nb_colors).long() + y = self.colors[x * m].permute(0, 3, 1, 2) + s = y.shape + y = y[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale) + y = y.reshape(s[0], s[1], s[2] * scale, s[3] * scale) + + if grids: + for t in range(self.thickness): + y[:, :, :, torch.arange(t, y.size(3), scale)] = self.grid_gray + y[:, :, torch.arange(t, y.size(2), scale), :] = self.grid_gray + if self.dots: + z = y.reshape( + y.size(0), + y.size(1), + y.size(2) // scale, + scale, + y.size(3) // scale, + scale, + ) + z = z[ + :, + :, + :, + scale // 2 - 1 : scale // 2 + 2, + :, + scale // 2 - 1 : scale // 2 + 2, + ] + zz = (z == self.background_gray).min(dim=1, keepdim=True).values + z[...] 
= zz * self.grid_gray + (zz == False) * z for n in range(m.size(0)): for i in range(m.size(1)): for j in range(m.size(2)): - if m[n, i, j] == 0: - for k in range(2, scale - 2): - for l in [0, 1]: - x[n, :, i * scale + k, j * scale + k - l] = 0 - x[ - n, :, i * scale + scale - 1 - k, j * scale + k - l - ] = 0 + if x[n, i, j] >= self.nb_colors: + # for k in range(3, scale - 2): + c = self.colors[x[n, i, j] - self.nb_colors][:, None, None] + # y[n, :, i * scale + k, j * scale + k] = c + # y[n, :, i * scale + k, j * scale + scale - k] = c + y[ + n, + :, + i * scale + 3 : i * scale + scale - 2, + j * scale + 3 : j * scale + scale - 2, + ] = c + + y = y[:, :, 1:, 1:] + + return y + + def add_frame(self, img, colors, thickness): + if thickness > 0: + result = img.new( + img.size(0), + img.size(1), + img.size(2) + 2 * thickness, + img.size(3) + 2 * thickness, + ) + + result[...] = colors[:, :, None, None] + result[:, :, thickness:-thickness, thickness:-thickness] = img + else: + result = img - return x + return result - def save_image( + def save_quizzes_as_image( self, result_dir, filename, - prompts, - answers, - predicted_prompts=None, - predicted_answers=None, + quizzes, + predicted_parts=None, + correct_parts=None, + comments=None, + comment_height=48, nrow=4, - margin=8, + grids=True, + margin=12, + delta=False, + delta_highlight=False, ): - S = self.height * self.width - As = prompts[:, 0 * (S + 1) : 0 * (S + 1) + S].view(-1, self.height, self.width) - f_As = prompts[:, 1 * (S + 1) : 1 * (S + 1) + S].view( - -1, self.height, self.width - ) - Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S].view(-1, self.height, self.width) - prompts = torch.cat([As, f_As, Bs], dim=2) - answers = answers.reshape(answers.size(0), self.height, self.width) - - if predicted_prompts is None: - predicted_prompts = 255 + quizzes = quizzes.to("cpu") - if predicted_answers is None: - predicted_answers = 255 - - def add_frame(x, c, margin, bottom=False): - if bottom: - h, w, di, dj = x.size(2) + 
margin, x.size(3), 0, 0 - else: - h, w, di, dj = ( - x.size(2) + 2 * margin, - x.size(3) + 2 * margin, - margin, - margin, - ) - - y = x.new_full((x.size(0), x.size(1), h, w), 0) - - if type(c) is int: - y[...] = c - else: - c = c.long()[:, None] - c = ( - (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long())) - * torch.tensor([64, 64, 64], device=c.device) - + (c == 1).long() * torch.tensor([0, 255, 0], device=c.device) - + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device) - + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device) - ) - y[...] = c[:, :, None, None] - - y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x + S = self.height * self.width - return y + A, f_A, B, f_B = ( + quizzes.reshape(quizzes.size(0), 4, S) + .reshape(quizzes.size(0), 4, self.height, self.width) + .permute(1, 0, 2, 3) + ) - img_prompts = torch.cat( + frame, white, gray, green, red = torch.tensor( [ - add_frame( - add_frame(self.frame2img(x), c=0, margin=1), - c=predicted_prompts, - margin=margin, - ) - for x in prompts.to("cpu").split(split_size=self.width, dim=2) + [self.grid_gray, self.grid_gray, self.grid_gray], + [255, 255, 255], + [200, 200, 200], + [0, 255, 0], + [255, 0, 0], ], - dim=3, + device=quizzes.device, ) - h = img_prompts.size(2) - img_answers = add_frame( - add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1), - c=predicted_answers, - margin=margin, - ) + thickness = self.thickness - separator_size = 2 * margin + if delta: + u = (A != f_A).long() + img_delta_A = self.add_frame( + self.grid2img(u, grids=grids), frame[None, :], thickness=thickness + ) + img_delta_A = img_delta_A.min(dim=1, keepdim=True).values.expand_as( + img_delta_A + ) + u = (B != f_B).long() + img_delta_B = self.add_frame( + self.grid2img(u, grids=grids), frame[None, :], thickness=thickness + ) + img_delta_B = img_delta_B.min(dim=1, keepdim=True).values.expand_as( + img_delta_B + ) - separator = img_prompts.new_full( - ( - img_prompts.size(0), - 
img_prompts.size(1), - img_prompts.size(2), - separator_size, - ), - 255, + img_A = self.add_frame( + self.grid2img(A, grids=grids), frame[None, :], thickness=thickness ) - - marker = img_prompts.new_full( - ( - img_prompts.size(0), - img_prompts.size(1), - img_prompts.size(2), - separator_size, - ), - 255, + img_f_A = self.add_frame( + self.grid2img(f_A, grids=grids), frame[None, :], thickness=thickness + ) + img_B = self.add_frame( + self.grid2img(B, grids=grids), frame[None, :], thickness=thickness + ) + img_f_B = self.add_frame( + self.grid2img(f_B, grids=grids), frame[None, :], thickness=thickness ) - # marker[:, :, 0] = 0 - # marker[:, :, h - 1] = 0 + if delta_highlight: + q = (img_B == img_f_B).min(dim=1, keepdim=True).values.long() + img_f_B = q * (img_f_B // 4 + 192) + (1 - q) * img_f_B + + # predicted_parts Nx4 + # correct_parts Nx4 + + if predicted_parts is None: + colors = white[None, None, :].expand(-1, 4, -1) + else: + predicted_parts = predicted_parts.to("cpu") + if correct_parts is None: + colors = ( + predicted_parts[:, :, None] * gray[None, None, :] + + (1 - predicted_parts[:, :, None]) * white[None, None, :] + ) + else: + correct_parts = correct_parts.to("cpu") + colors = ( + predicted_parts[:, :, None] + * ( + (correct_parts[:, :, None] == 1).long() * green[None, None, :] + + (correct_parts[:, :, None] == 0).long() * gray[None, None, :] + + (correct_parts[:, :, None] == -1).long() * red[None, None, :] + ) + + (1 - predicted_parts[:, :, None]) * white[None, None, :] + ) + + separation = 6 - for k in range(1, 2 * separator_size - 8): - i = k - (separator_size - 4) - j = separator_size - 5 - abs(i) - marker[:, :, h // 2 - 1 + i, 2 + j] = 0 - marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0 + img_A = self.add_frame(img_A, colors[:, 0], thickness=separation) + img_f_A = self.add_frame(img_f_A, colors[:, 1], thickness=separation) + img_B = self.add_frame(img_B, colors[:, 2], thickness=separation) + img_f_B = self.add_frame(img_f_B, colors[:, 3], 
thickness=separation) - img = torch.cat( - [ - img_prompts, - marker, - img_answers, - ], - dim=3, - ) + img_A = self.add_frame(img_A, white[None, :], thickness=2) + img_f_A = self.add_frame(img_f_A, white[None, :], thickness=2) + img_B = self.add_frame(img_B, white[None, :], thickness=2) + img_f_B = self.add_frame(img_f_B, white[None, :], thickness=2) + + if delta: + img_delta_A = self.add_frame( + img_delta_A, colors[:, 0], thickness=separation + ) + img_delta_A = self.add_frame(img_delta_A, white[None, :], thickness=2) + img_delta_B = self.add_frame( + img_delta_B, colors[:, 0], thickness=separation + ) + img_delta_B = self.add_frame(img_delta_B, white[None, :], thickness=2) + img = torch.cat( + [img_A, img_f_A, img_delta_A, img_B, img_f_B, img_delta_B], dim=3 + ) + else: + img = torch.cat([img_A, img_f_A, img_B, img_f_B], dim=3) + + if comments is not None: + comment_img = [text_img(comment_height, img.size(3), t) for t in comments] + comment_img = torch.cat(comment_img, dim=0) + img = torch.cat([img, comment_img], dim=2) image_name = os.path.join(result_dir, filename) + torchvision.utils.save_image( img.float() / 255.0, image_name, @@ -191,147 +496,283 @@ class Grids(problem.Problem): ###################################################################### - def nb_token_values(self): - return len(self.colors) - # @torch.compile - def rec_coo_(self, nb_rec, min_height=3, min_width=3): - # @torch.compile - def overlap(ia, ja, ib, jb): - return ( - ia[1] >= ib[0] and ia[0] <= ib[1] and ja[1] >= jb[0] and ja[0] <= jb[1] - ) + def rec_coo( + self, + nb_rec, + min_height=3, + min_width=3, + surface_max=None, + prevent_overlap=False, + ): + if surface_max is None: + surface_max = self.height * self.width // 2 - if nb_rec == 3: + signature = (nb_rec, min_height, min_width, surface_max) + + try: + return self.cache_rec_coo[signature].pop() + except IndexError: + pass + except KeyError: + pass + + N = 10000 + while True: while True: - i = torch.randint(self.height + 1, 
(nb_rec, 2)).sort(dim=1).values - j = torch.randint(self.width + 1, (nb_rec, 2)).sort(dim=1).values - if ( - not ( - overlap(i[0], j[0], i[1], j[1]) - or overlap(i[0], j[0], i[2], j[2]) - or overlap(i[1], j[1], i[2], j[2]) - ) - and (i[:, 1] - i[:, 0]).min() >= min_height - and (j[:, 1] - j[:, 0]).min() >= min_width - ): + i = torch.randint(self.height, (N * nb_rec, 2)).sort(dim=-1).values + j = torch.randint(self.width, (N * nb_rec, 2)).sort(dim=-1).values + i[:, 1] += 1 + j[:, 1] += 1 + big_enough = ( + (i[:, 1] >= i[:, 0] + min_height) + & (j[:, 1] >= j[:, 0] + min_height) + & ((i[:, 1] - i[:, 0]) * (j[:, 1] - j[:, 0]) <= surface_max) + ) + + i, j = i[big_enough], j[big_enough] + + n = i.size(0) - i.size(0) % nb_rec + + if n > 0: break - return ( - (i[0, 0], j[0, 0], i[0, 1], j[0, 1]), - (i[1, 0], j[1, 0], i[1, 1], j[1, 1]), - (i[2, 0], j[2, 0], i[2, 1], j[2, 1]), - ) - # That's quite a tensorial spaghetti mess to sample - # non-overlapping rectangles quickly, but made the generation of - # 100k samples go from 1h50 with a lame pure python code to 3min30s - # with this one. 
- # @torch.compile - def rec_coo(self, nb_rec, min_height=3, min_width=3): - nb_trials = 200 + i = i[:n].reshape(n // nb_rec, nb_rec, -1) + j = j[:n].reshape(n // nb_rec, nb_rec, -1) + + if prevent_overlap: + can_fit = ((i[:, :, 1] - i[:, :, 0]) * (j[:, :, 1] - j[:, :, 0])).sum( + dim=-1 + ) <= self.height * self.width + i, j = i[can_fit], j[can_fit] + if nb_rec == 2: + A_i1, A_i2, A_j1, A_j2 = ( + i[:, 0, 0], + i[:, 0, 1], + j[:, 0, 0], + j[:, 0, 1], + ) + B_i1, B_i2, B_j1, B_j2 = ( + i[:, 1, 0], + i[:, 1, 1], + j[:, 1, 0], + j[:, 1, 1], + ) + no_overlap = ( + (A_i1 >= B_i2) + | (A_i2 <= B_i1) + | (A_j1 >= B_j2) + | (A_j2 <= B_j1) + ) + i, j = (i[no_overlap], j[no_overlap]) + elif nb_rec == 3: + A_i1, A_i2, A_j1, A_j2 = ( + i[:, 0, 0], + i[:, 0, 1], + j[:, 0, 0], + j[:, 0, 1], + ) + B_i1, B_i2, B_j1, B_j2 = ( + i[:, 1, 0], + i[:, 1, 1], + j[:, 1, 0], + j[:, 1, 1], + ) + C_i1, C_i2, C_j1, C_j2 = ( + i[:, 2, 0], + i[:, 2, 1], + j[:, 2, 0], + j[:, 2, 1], + ) + no_overlap = ( + ( + (A_i1 >= B_i2) + | (A_i2 <= B_i1) + | (A_j1 >= B_j2) + | (A_j2 <= B_j1) + ) + & ( + (A_i1 >= C_i2) + | (A_i2 <= C_i1) + | (A_j1 >= C_j2) + | (A_j2 <= C_j1) + ) + & ( + (B_i1 >= C_i2) + | (B_i2 <= C_i1) + | (B_j1 >= C_j2) + | (B_j2 <= C_j1) + ) + ) + i, j = (i[no_overlap], j[no_overlap]) + else: + assert nb_rec == 1 - while True: - v = ( + if i.size(0) > 1: + break + + self.cache_rec_coo[signature] = [ + [ ( - torch.rand(nb_trials * nb_rec, self.height + 1, device=self.device) - .sort(dim=-1) - .indices - < 2 + i[n, k, 0].item(), + j[n, k, 0].item(), + i[n, k, 1].item(), + j[n, k, 1].item(), ) - .long() - .cumsum(dim=1) - == 1 - ).long() + for k in range(nb_rec) + ] + for n in range(i.size(0)) + ] + + return self.cache_rec_coo[signature].pop() - h = ( + ###################################################################### + + def contact_matrices(self, rn, ri, rj, rz): + n = torch.arange(self.nb_rec_max) + return ( + ( ( - torch.rand(nb_trials * nb_rec, self.width + 1, device=self.device) - 
.sort(dim=-1) - .indices - < 2 + ( + (ri[:, :, None, 0] == ri[:, None, :, 1] + 1) + | (ri[:, :, None, 1] + 1 == ri[:, None, :, 0]) + ) + & (rj[:, :, None, 0] <= rj[:, None, :, 1]) + & (rj[:, :, None, 1] >= rj[:, None, :, 0]) ) - .long() - .cumsum(dim=1) - == 1 - ).long() + | ( + ( + (rj[:, :, None, 0] == rj[:, None, :, 1] + 1) + | (rj[:, :, None, 1] + 1 == rj[:, None, :, 0]) + ) + & (ri[:, :, None, 0] <= ri[:, None, :, 1]) + & (ri[:, :, None, 1] >= ri[:, None, :, 0]) + ) + ) + # & (rz[:, :, None] == rz[:, None, :]) + & (n[None, :, None] < rn[:, None, None]) + & (n[None, None, :] < n[None, :, None]) + ) - i = torch.logical_and( - v.sum(dim=-1) >= min_height, h.sum(dim=-1) >= min_width + def sample_rworld_states(self, N=1000): + while True: + ri = ( + torch.randint(self.height - 2, (N, self.nb_rec_max, 2)) + .sort(dim=2) + .values + ) + ri[:, :, 1] += 2 + rj = ( + torch.randint(self.width - 2, (N, self.nb_rec_max, 2)) + .sort(dim=2) + .values + ) + rj[:, :, 1] += 2 + rn = torch.randint(self.nb_rec_max - 1, (N,)) + 2 + rz = torch.randint(2, (N, self.nb_rec_max)) + rc = torch.randint(self.nb_colors - 1, (N, self.nb_rec_max)) + 1 + n = torch.arange(self.nb_rec_max) + nb_collisions = ( + ( + (ri[:, :, None, 0] <= ri[:, None, :, 1]) + & (ri[:, :, None, 1] >= ri[:, None, :, 0]) + & (rj[:, :, None, 0] <= rj[:, None, :, 1]) + & (rj[:, :, None, 1] >= rj[:, None, :, 0]) + & (rz[:, :, None] == rz[:, None, :]) + & (n[None, :, None] < rn[:, None, None]) + & (n[None, None, :] < n[None, :, None]) + ) + .long() + .flatten(1) + .sum(dim=1) ) - v, h = v[i], h[i] - v = v[: v.size(0) - v.size(0) % nb_rec] - h = h[: h.size(0) - h.size(0) % nb_rec] - v = v.reshape(v.size(0) // nb_rec, nb_rec, -1) - h = h.reshape(h.size(0) // nb_rec, nb_rec, -1) + no_collision = nb_collisions == 0 - r = v[:, :, :, None] * h[:, :, None, :] + if no_collision.any(): + print(no_collision.long().sum() / N) + self.rn = rn[no_collision] + self.ri = ri[no_collision] + self.rj = rj[no_collision] + self.rz = 
rz[no_collision] + self.rc = rc[no_collision] - valid = r.sum(dim=1).flatten(1).max(dim=-1).values == 1 + nb_contact = ( + self.contact_matrices(rn, ri, rj, rz).long().flatten(1).sum(dim=1) + ) - v = v[valid] - h = h[valid] + self.rcontact = nb_contact > 0 + self.rfree = torch.full((self.rn.size(0),), True) - if v.size(0) > 0: break - av = torch.arange(v.size(2), device=self.device)[None, :] - ah = torch.arange(h.size(2), device=self.device)[None, :] + def get_recworld_state(self): + if not self.rfree.any(): + self.sample_rworld_states() + k = torch.arange(self.rn.size(0))[self.rfree] + k = k[torch.randint(k.size(0), (1,))].item() + self.rfree[k] = False + return self.rn[k], self.ri[k], self.rj[k], self.rz[k], self.rc[k] - return [ - (i1.item(), j1.item(), i2.item() + 1, j2.item() + 1) - for i1, j1, i2, j2 in zip( - v.size(2) - (v[0] * (v.size(2) - av)).max(dim=-1).values, - h.size(2) - (h[0] * (h.size(2) - ah)).max(dim=-1).values, - (v[0] * av).max(dim=-1).values, - (h[0] * ah).max(dim=-1).values, - ) - ] + def draw_state(self, X, rn, ri, rj, rz, rc): + for n in sorted(list(range(rn)), key=lambda n: rz[n].item()): + X[ri[n, 0] : ri[n, 1] + 1, rj[n, 0] : rj[n, 1] + 1] = rc[n] - # @torch.compile - def rec_coo_(self, x, n, min_height=3, min_width=3): - collision = x.new(x.size()) - while True: - collision[...] 
= 0 - result = [] - for _ in range(n): - while True: - i1, i2 = torch.randint(x.size(0), (2,)) - if i1 + min_height <= i2: - break - while True: - j1, j2 = torch.randint(x.size(1), (2,)) - if j1 + min_width <= j2: - break - collision[i1:i2, j1:j2] += 1 - if collision.max() > 1: - break - result.append((i1, j1, i2, j2)) - if collision.max() == 1: - break - return result + def task_recworld_immobile(self, A, f_A, B, f_B): + for X, f_X in [(A, f_A), (B, f_B)]: + rn, ri, rj, rz, rc = self.get_recworld_state() + self.draw_state(X, rn, ri, rj, rz, rc) + ri += 1 + self.draw_state(f_X, rn, ri, rj, rz, rc) ###################################################################### # @torch.compile def task_replace_color(self, A, f_A, B, f_B): nb_rec = 3 - c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1 + c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1 for X, f_X in [(A, f_A), (B, f_B)]: - r = self.rec_coo(nb_rec) + r = self.rec_coo(nb_rec, prevent_overlap=True) for n in range(nb_rec): i1, j1, i2, j2 = r[n] X[i1:i2, j1:j2] = c[n] f_X[i1:i2, j1:j2] = c[n if n > 0 else -1] + # @torch.compile + def task_symmetry(self, A, f_A, B, f_B): + a, b = torch.randint(2, (2,)) + nb_rec = 3 + c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + while True: + r = self.rec_coo(nb_rec, prevent_overlap=True) + if min([x[2] for x in r]) > self.height // 2 + 1: + break + for n in range(nb_rec): + i1, j1, i2, j2 = r[n] + X[i1:i2, j1:j2] = c[n] + f_X[i1:i2, j1:j2] = c[n] + X[: self.height // 2] = 0 + f_X[: self.height // 2] = f_X.flip([0])[: self.height // 2] + if a == 1: + X[...] = X.flip((0,)) + f_X[...] = f_X.flip((0,)) + if b == 1: + X[...] = X.clone().t() + f_X[...] 
= f_X.clone().t() + # @torch.compile def task_translate(self, A, f_A, B, f_B): - di, dj = torch.randint(3, (2,)) - 1 + while True: + di, dj = torch.randint(3, (2,)) - 1 + if di.abs() + dj.abs() > 0: + break + nb_rec = 3 - c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1 + c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1 for X, f_X in [(A, f_A), (B, f_B)]: while True: - r = self.rec_coo(nb_rec) + r = self.rec_coo(nb_rec, prevent_overlap=True) i1, j1, i2, j2 = r[nb_rec - 1] if ( i1 + di >= 0 @@ -353,11 +794,11 @@ class Grids(problem.Problem): def task_grow(self, A, f_A, B, f_B): di, dj = torch.randint(2, (2,)) * 2 - 1 nb_rec = 3 - c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1 - direction = torch.randint(2, (1,)) + c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1 + direction = torch.randint(2, (1,)).item() for X, f_X in [(A, f_A), (B, f_B)]: while True: - r = self.rec_coo(nb_rec) + r = self.rec_coo(nb_rec, prevent_overlap=True) i1, j1, i2, j2 = r[nb_rec - 1] if i1 + 3 < i2 and j1 + 3 < j2: break @@ -376,13 +817,13 @@ class Grids(problem.Problem): f_X[i1:i2, j1:j2] = c[n] # @torch.compile - def task_color_grow(self, A, f_A, B, f_B): + def task_half_fill(self, A, f_A, B, f_B): di, dj = torch.randint(2, (2,)) * 2 - 1 nb_rec = 3 - c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1 - direction = torch.randint(4, (1,)) + c = torch.randperm(self.nb_colors - 1)[: 2 * nb_rec] + 1 + direction = torch.randint(4, (1,)).item() for X, f_X in [(A, f_A), (B, f_B)]: - r = self.rec_coo(nb_rec) + r = self.rec_coo(nb_rec, prevent_overlap=True) for n in range(nb_rec): i1, j1, i2, j2 = r[n] X[i1:i2, j1:j2] = c[2 * n] @@ -420,27 +861,34 @@ class Grids(problem.Problem): # @torch.compile def task_frame(self, A, f_A, B, f_B): nb_rec = 3 - c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1 + c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1 for X, f_X in [(A, f_A), (B, f_B)]: - r = self.rec_coo(nb_rec) + r = self.rec_coo(nb_rec, prevent_overlap=True) 
for n in range(nb_rec): i1, j1, i2, j2 = r[n] X[i1:i2, j1:j2] = c[n] - f_X[i1:i2, j1:j2] = c[n] if n == nb_rec - 1: - f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0 + f_X[i1:i2, j1] = c[n] + f_X[i1:i2, j2 - 1] = c[n] + f_X[i1, j1:j2] = c[n] + f_X[i2 - 1, j1:j2] = c[n] + else: + f_X[i1:i2, j1:j2] = c[n] # @torch.compile def task_detect(self, A, f_A, B, f_B): nb_rec = 3 - c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1 + c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1 for X, f_X in [(A, f_A), (B, f_B)]: - r = self.rec_coo(nb_rec) + r = self.rec_coo(nb_rec, prevent_overlap=True) for n in range(nb_rec): i1, j1, i2, j2 = r[n] X[i1:i2, j1:j2] = c[n] + f_X[i1:i2, j1:j2] = c[n] if n < nb_rec - 1: - f_X[i1, j1] = c[-1] + for k in range(2): + f_X[i1 + k, j1] = c[-1] + f_X[i1, j1 + k] = c[-1] # @torch.compile def contact(self, X, i, j, q): @@ -478,37 +926,66 @@ class Grids(problem.Problem): return no, nq, nq_diag - # @torch.compile - def task_count(self, A, f_A, B, f_B): - N = (torch.randint(4, (1,)) + 2).item() - c = torch.randperm(len(self.colors) - 1)[:N] + 1 + def REMOVED_task_count(self, A, f_A, B, f_B): + while True: + error = False + + N = 3 + c = torch.zeros(N + 2, dtype=torch.int64) + c[1:] = torch.randperm(self.nb_colors - 1)[: N + 1] + 1 + + for X, f_X in [(A, f_A), (B, f_B)]: + if not hasattr(self, "cache_count") or len(self.cache_count) == 0: + self.cache_count = list( + grow_islands( + 1000, + self.height, + self.width, + nb_seeds=self.height * self.width // 8, + nb_iterations=self.height * self.width // 5, + ) + ) - for X, f_X in [(A, f_A), (B, f_B)]: - nb = torch.zeros(N, dtype=torch.int64) - q = torch.randint(N, (self.height * self.width,)) - k = torch.randperm(self.height * self.width) - for p in range(self.height * self.width): - i, j = k[p] % self.height, k[p] // self.height - no, nq, nq_diag = self.contact(X, i, j, c[q[p]]) - if no == 0 and nq_diag == 0: - if nq == 0: - if nb[q[p]] < self.width: - X[i, j] = c[q[p]] - nb[q[p]] += 1 - if nq 
== 1: - X[i, j] = c[q[p]] - - for n in range(N): - for j in range(nb[n]): - f_X[n, j] = c[n] + X[...] = self.cache_count.pop() + + # k = (X.max() + 1 + (c.size(0) - 1)).item() + # V = torch.arange(k) // (c.size(0) - 1) + # V = (V + torch.rand(V.size())).sort().indices[: X.max() + 1] % ( + # c.size(0) - 1 + # ) + 1 + + V = torch.randint(N, (X.max() + 1,)) + 1 + V[0] = 0 + NB = F.one_hot(c[V]).sum(dim=0) + X[...] = c[V[X]] + f_X[...] = X + + if F.one_hot(X.flatten()).max(dim=0).values.sum().item() >= 3: + m = NB[c[:-1]].max() + if (NB[c[:-1]] == m).long().sum() == 1: + for e in range(1, N + 1): + if NB[c[e]] == m: + a = (f_X == c[e]).long() + f_X[...] = (1 - a) * f_X + a * c[-1] + else: + error = True + break + + if not error: + break + + assert F.one_hot(A.flatten()).max(dim=0).values.sum() >= 3 # @torch.compile def task_trajectory(self, A, f_A, B, f_B): - c = torch.randperm(len(self.colors) - 1)[:2] + 1 + c = torch.randperm(self.nb_colors - 1)[:2] + 1 for X, f_X in [(A, f_A), (B, f_B)]: while True: di, dj = torch.randint(7, (2,)) - 3 - i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,)) + i, j = ( + torch.randint(self.height, (1,)).item(), + torch.randint(self.width, (1,)).item(), + ) if ( abs(di) + abs(dj) > 0 and i + 2 * di >= 0 @@ -532,7 +1009,7 @@ class Grids(problem.Problem): # @torch.compile def task_bounce(self, A, f_A, B, f_B): - c = torch.randperm(len(self.colors) - 1)[:3] + 1 + c = torch.randperm(self.nb_colors - 1)[:3] + 1 for X, f_X in [(A, f_A), (B, f_B)]: # @torch.compile def free(i, j): @@ -549,8 +1026,9 @@ class Grids(problem.Problem): X[...] 
= 0 for _ in range((self.height * self.width) // 10): - i, j = torch.randint(self.height, (1,)), torch.randint( - self.width, (1,) + i, j = ( + torch.randint(self.height, (1,)).item(), + torch.randint(self.width, (1,)).item(), ) X[i, j] = c[0] f_X[i, j] = c[0] @@ -560,7 +1038,10 @@ class Grids(problem.Problem): if abs(di) + abs(dj) == 1: break - i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,)) + i, j = ( + torch.randint(self.height, (1,)).item(), + torch.randint(self.width, (1,)).item(), + ) X[i, j] = c[1] f_X[i, j] = c[1] @@ -584,6 +1065,7 @@ class Grids(problem.Problem): f_X[i, j] = c[2] if l <= 1: X[i, j] = c[2] + f_X[i, j] = c[1] if l >= self.width: break @@ -596,33 +1078,39 @@ class Grids(problem.Problem): # @torch.compile def task_scale(self, A, f_A, B, f_B): - c = torch.randperm(len(self.colors) - 1)[:2] + 1 + c = torch.randperm(self.nb_colors - 1)[:2] + 1 - i, j = torch.randint(self.height // 2, (1,)), torch.randint( - self.width // 2, (1,) + i, j = ( + torch.randint(self.height // 2, (1,)).item(), + torch.randint(self.width // 2, (1,)).item(), ) for X, f_X in [(A, f_A), (B, f_B)]: for _ in range(3): while True: - i1, j1 = torch.randint(self.height // 2 + 1, (1,)), torch.randint( - self.width // 2 + 1, (1,) + i1, j1 = ( + torch.randint(self.height // 2 + 1, (1,)).item(), + torch.randint(self.width // 2 + 1, (1,)).item(), ) - i2, j2 = torch.randint(self.height // 2 + 1, (1,)), torch.randint( - self.width // 2 + 1, (1,) + i2, j2 = ( + torch.randint(self.height // 2 + 1, (1,)).item(), + torch.randint(self.width // 2 + 1, (1,)).item(), ) if i1 < i2 and j1 < j2 and min(i2 - i1, j2 - j1) <= 3: break X[i + i1 : i + i2, j + j1 : j + j2] = c[0] f_X[2 * i1 : 2 * i2, 2 * j1 : 2 * j2] = c[0] - X[i, j] = c[1] - f_X[0:2, 0:2] = c[1] + for k in range(2): + X[i + k, j] = c[1] + X[i, j + k] = c[1] + f_X[i + k, j] = c[1] + f_X[i, j + k] = c[1] # @torch.compile def task_symbols(self, A, f_A, B, f_B): nb_rec = 4 - c = torch.randperm(len(self.colors) - 
1)[: nb_rec + 1] + 1 + c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1 delta = 3 for X, f_X in [(A, f_A), (B, f_B)]: while True: @@ -634,34 +1122,45 @@ class Grids(problem.Problem): if d.min() > delta: break - for k in range(1, nb_rec): - X[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k] - ai, aj = i.float().mean(), j.float().mean() - q = torch.randint(3, (1,)) + 1 - - X[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0] - X[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0] - X[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0] - X[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0] + q = torch.randint(3, (1,)).item() + 1 assert i[q] != ai and j[q] != aj - X[ + for Z in [X, f_X]: + for k in range(0, nb_rec): + Z[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k] + # Z[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0] + # Z[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0] + # Z[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0] + # Z[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0] + + # f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q] + + f_X[i[0] + delta // 2, j[0] + delta // 2] = c[q] + # f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q] + + ii, jj = ( i[0] + delta // 2 + (i[q] - ai).sign().long(), j[0] + delta // 2 + (j[q] - aj).sign().long(), - ] = c[nb_rec] + ) + + X[ii, jj] = c[nb_rec] + X[i[0] + delta // 2, jj] = c[nb_rec] + X[ii, j[0] + delta // 2] = c[nb_rec] - f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q] + f_X[ii, jj] = c[nb_rec] + f_X[i[0] + delta // 2, jj] = c[nb_rec] + f_X[ii, j[0] + delta // 2] = c[nb_rec] # @torch.compile - def task_ortho(self, A, f_A, B, f_B): + def task_isometry(self, A, f_A, B, f_B): nb_rec = 3 di, dj = torch.randint(3, (2,)) - 1 o = torch.tensor([[0.0, 1.0], [-1.0, 0.0]]) m = torch.eye(2) - for _ in range(torch.randint(4, (1,))): + for _ in range(torch.randint(4, (1,)).item()): m = m @ o if torch.rand(1) < 0.5: m[0, :] = -m[0, :] @@ -673,7 +1172,7 @@ class Grids(problem.Problem): X[...] 
= 0 f_X[...] = 0 - c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1 + c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1 for r in range(nb_rec): while True: @@ -710,9 +1209,88 @@ class Grids(problem.Problem): ): break + def compute_distance(self, walls, goal_i, goal_j): + max_length = walls.numel() + dist = torch.full_like(walls, max_length) + + dist[goal_i, goal_j] = 0 + pred_dist = torch.empty_like(dist) + + while True: + pred_dist.copy_(dist) + dist[1:-1, 1:-1] = ( + torch.cat( + ( + dist[None, 1:-1, 1:-1], + dist[None, 1:-1, 0:-2], + dist[None, 2:, 1:-1], + dist[None, 1:-1, 2:], + dist[None, 0:-2, 1:-1], + ), + 0, + ).min(dim=0)[0] + + 1 + ) + + dist = walls * max_length + (1 - walls) * dist + + if dist.equal(pred_dist): + return dist * (1 - walls) + # @torch.compile - def task_islands(self, A, f_A, B, f_B): - pass + def REMOVED_task_distance(self, A, f_A, B, f_B): + c = torch.randperm(self.nb_colors - 1)[:3] + 1 + dist0 = torch.empty(self.height + 2, self.width + 2) + dist1 = torch.empty(self.height + 2, self.width + 2) + for X, f_X in [(A, f_A), (B, f_B)]: + nb_rec = torch.randint(3, (1,)).item() + 1 + while True: + r = self.rec_coo(nb_rec, prevent_overlap=True) + X[...] = 0 + f_X[...] = 0 + for n in range(nb_rec): + i1, j1, i2, j2 = r[n] + X[i1:i2, j1:j2] = c[0] + f_X[i1:i2, j1:j2] = c[0] + while True: + i0, j0 = ( + torch.randint(self.height, (1,)).item(), + torch.randint(self.width, (1,)).item(), + ) + if X[i0, j0] == 0: + break + while True: + i1, j1 = ( + torch.randint(self.height, (1,)).item(), + torch.randint(self.width, (1,)).item(), + ) + if X[i1, j1] == 0: + break + dist1[...] = 1 + dist1[1:-1, 1:-1] = (X != 0).long() + dist1[...] = self.compute_distance(dist1, i1 + 1, j1 + 1) + if ( + dist1[i0 + 1, j0 + 1] >= 1 + and dist1[i0 + 1, j0 + 1] < self.height * 4 + ): + break + + dist0[...] = 1 + dist0[1:-1, 1:-1] = (X != 0).long() + dist0[...] 
= self.compute_distance(dist0, i0 + 1, j0 + 1) + + dist0 = dist0[1:-1, 1:-1] + dist1 = dist1[1:-1, 1:-1] + + D = dist1[i0, j0] + for d in range(1, D): + M = (dist0 == d) & (dist1 == D - d) + f_X[...] = (1 - M) * f_X + M * c[1] + + X[i0, j0] = c[2] + f_X[i0, j0] = c[2] + X[i1, j1] = c[2] + f_X[i1, j1] = c[2] # for X, f_X in [(A, f_A), (B, f_B)]: # n = torch.arange(self.height * self.width).reshape(self.height, self.width) @@ -722,80 +1300,528 @@ class Grids(problem.Problem): # i,j=q%self.height,q//self.height # if - ###################################################################### + # @torch.compile + def TOO_HARD_task_puzzle(self, A, f_A, B, f_B): + S = 4 + i0, j0 = (self.height - S) // 2, (self.width - S) // 2 + c = torch.randperm(self.nb_colors - 1)[:4] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + while True: + f_X[...] = 0 + h = list(torch.randperm(c.size(0))) + n = torch.zeros(c.max() + 1) + for _ in range(2): + k = torch.randperm(S * S) + for q in k: + i, j = q % S + i0, q // S + j0 + if f_X[i, j] == 0: + r, s, t, u = ( + f_X[i - 1, j], + f_X[i, j - 1], + f_X[i + 1, j], + f_X[i, j + 1], + ) + r, s, t, u = torch.tensor([r, s, t, u])[torch.randperm(4)] + if r > 0 and n[r] < 6: + n[r] += 1 + f_X[i, j] = r + elif s > 0 and n[s] < 6: + n[s] += 1 + f_X[i, j] = s + elif t > 0 and n[t] < 6: + n[t] += 1 + f_X[i, j] = t + elif u > 0 and n[u] < 6: + n[u] += 1 + f_X[i, j] = u + else: + if len(h) > 0: + d = c[h.pop()] + n[d] += 1 + f_X[i, j] = d + + if n.sum() == S * S: + break - def all_tasks(self): - return [ - self.task_replace_color, - self.task_translate, - self.task_grow, - self.task_color_grow, - self.task_frame, - self.task_detect, - self.task_count, - self.task_trajectory, - self.task_bounce, - self.task_scale, - self.task_symbols, - self.task_ortho, - # self.task_islands, - ] + k = 0 + for d in range(4): + while True: + ii, jj = ( + torch.randint(self.height, (1,)).item(), + torch.randint(self.width, (1,)).item(), + ) + e = 0 + for i in range(S): + for j in 
range(S): + if ( + ii + i >= self.height + or jj + j >= self.width + or ( + f_X[i + i0, j + j0] == c[d] + and X[ii + i, jj + j] > 0 + ) + ): + e = 1 + if e == 0: + break + for i in range(S): + for j in range(S): + if f_X[i + i0, j + j0] == c[d]: + X[ii + i, jj + j] = c[d] + + def TOO_MESSY_task_islands(self, A, f_A, B, f_B): + c = torch.randperm(self.nb_colors - 1)[:2] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + if not hasattr(self, "cache_islands") or len(self.cache_islands) == 0: + self.cache_islands = list( + grow_islands( + 1000, + self.height, + self.width, + nb_seeds=self.height * self.width // 20, + nb_iterations=self.height * self.width // 2, + ) + ) + + A = self.cache_islands.pop() + + while True: + i, j = ( + torch.randint(self.height // 2, (1,)).item(), + torch.randint(self.width // 2, (1,)).item(), + ) + if A[i, j] > 0: + break + + X[...] = (A > 0) * c[0] + f_X[...] = (A == A[i, j]) * c[1] + ((A > 0) & (A != A[i, j])) * c[0] + f_X[i, j] = X[i, j] + X[i, j] = c[1] + + # @torch.compile + def TOO_HARD_task_stack(self, A, f_A, B, f_B): + N = 5 + c = torch.randperm(self.nb_colors - 1)[:N] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + i1, j1, i2, j2 = ( + self.height // 2 - 1, + self.width // 2 - 1, + self.height // 2 + 1, + self.width // 2 + 1, + ) + op = torch.tensor((0, 1, 2, 3) * 4) + op = op[torch.randperm(op.size(0))[:9]] + for q in range(op.size(0)): + u = 3 * (q // 3) + v = 3 * (q % 3) + d = c[torch.randint(N, (1,)).item()] + # X[u+1,v+1]=d + if op[q] == 0: # right + X[u : u + 3, v + 2] = d + elif op[q] == 1: # let + X[u : u + 3, v] = d + elif op[q] == 2: # bottom + X[u + 2, v : v + 3] = d + elif op[q] == 3: # top + X[u, v : v + 3] = d + + if q == 0: + f_X[i1:i2, j1:j2] = d + elif op[q] == 0: # right + f_X[i1:i2, j2] = d + j2 += 1 + elif op[q] == 1: # let + j1 -= 1 + f_X[i1:i2, j1] = d + elif op[q] == 2: # bottom + f_X[i2, j1:j2] = d + i2 += 1 + elif op[q] == 3: # top + i1 -= 1 + f_X[i1, j1:j2] = d + + def randint(self, *m): + m = torch.tensor(m) + 
return (torch.rand(m.size()) * m).long() + + def TOO_HARD_task_matrices(self, A, f_A, B, f_B): + N = 6 + c = torch.randperm(self.nb_colors - 1)[:N] + 1 + + for X, f_X in [(A, f_A), (B, f_B)]: + M1 = torch.randint(2, (5, 5)) + M2 = torch.randint(2, (5, 5)) + P = M1 @ M2 + for i in range(5): + for j in range(5): + X[i, j] = c[M1[i, j]] + X[i, j + 5] = c[M2[i, j]] + f_X[i, j] = c[M1[i, j]] + f_X[i, j + 5] = c[M2[i, j]] + f_X[i + 5, j + 5] = c[P[i, j]] + + def TOO_HARD_task_compute(self, A, f_A, B, f_B): + N = 6 + c = torch.randperm(self.nb_colors - 1)[:N] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + v = torch.randint((self.width - 1) // 2, (N,)) + 1 + chain = torch.randperm(N) + eq = [] + for i in range(chain.size(0) - 1): + i1, i2 = chain[i], chain[i + 1] + v1, v2 = v[i1], v[i2] + k = torch.arange(self.width // 2) + 1 + d = ((k[None, :] * v1 - k[:, None] * v2) == 0).nonzero() + 1 + d = d[torch.randint(d.size(0), (1,)).item()] + w1, w2 = d + eq.append((c[i1], w1, c[i2], w2)) + + ii = torch.randperm(self.height - 2)[: len(eq)] + + for k, x in enumerate(eq): + i = ii[k] + c1, w1, c2, w2 = x + s = torch.randint(self.width - (w1 + w2) + 1, (1,)).item() + X[i, s : s + w1] = c1 + X[i, s + w1 : s + w1 + w2] = c2 + f_X[i, s : s + w1] = c1 + f_X[i, s + w1 : s + w1 + w2] = c2 + + i1, i2 = torch.randperm(N)[:2] + v1, v2 = v[i1], v[i2] + k = torch.arange(self.width // 2) + 1 + d = ((k[None, :] * v1 - k[:, None] * v2) == 0).nonzero() + 1 + d = d[torch.randint(d.size(0), (1,)).item()] + w1, w2 = d + c1, c2 = c[i1], c[i2] + s = 0 # torch.randint(self.width - (w1 + w2) + 1, (1,)).item() + i = self.height - 1 + X[i, s : s + w1] = c1 + X[i, s + w1 : s + w1 + 1] = c2 + f_X[i, s : s + w1] = c1 + f_X[i, s + w1 : s + w1 + w2] = c2 + + # @torch.compile + # [ai1,ai2] [bi1,bi2] + def task_contact(self, A, f_A, B, f_B): + def rec_dist(a, b): + ai1, aj1, ai2, aj2 = a + bi1, bj1, bi2, bj2 = b + v = max(ai1 - bi2, bi1 - ai2) + h = max(aj1 - bj2, bj1 - aj2) + return min(max(v, 0) + max(h + 1, 0), 
max(v + 1, 0) + max(h, 0)) + + nb_rec = 3 + c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + while True: + r = self.rec_coo(nb_rec, prevent_overlap=True) + d = [rec_dist(r[0], r[k]) for k in range(nb_rec)] + if min(d[1:]) == 0: + break + + for n in range(nb_rec): + i1, j1, i2, j2 = r[n] + X[i1:i2, j1:j2] = c[n] + f_X[i1:i2, j1:j2] = c[n] + if d[n] == 0: + f_X[i1, j1:j2] = c[0] + f_X[i2 - 1, j1:j2] = c[0] + f_X[i1:i2, j1] = c[0] + f_X[i1:i2, j2 - 1] = c[0] + + # @torch.compile + # [ai1,ai2] [bi1,bi2] + def task_corners(self, A, f_A, B, f_B): + polarity = torch.randint(2, (1,)).item() + nb_rec = 3 + c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + r = self.rec_coo(nb_rec, prevent_overlap=True) + + for n in range(nb_rec): + i1, j1, i2, j2 = r[n] + for k in range(2): + if polarity == 0: + X[i1 + k, j1] = c[n] + X[i2 - 1 - k, j2 - 1] = c[n] + X[i1, j1 + k] = c[n] + X[i2 - 1, j2 - 1 - k] = c[n] + else: + X[i1 + k, j2 - 1] = c[n] + X[i2 - 1 - k, j1] = c[n] + X[i1, j2 - 1 - k] = c[n] + X[i2 - 1, j1 + k] = c[n] + f_X[i1:i2, j1:j2] = c[n] + + def compdist(self, X, i, j): + dd = X.new_full((self.height + 2, self.width + 2), self.height * self.width) + d = dd[1:-1, 1:-1] + m = (X > 0).long() + d[i, j] = 0 + e = d.clone() + while True: + e[...] = d + d[...] = ( + d.min(dd[:-2, 1:-1] + 1) + .min(dd[2:, 1:-1] + 1) + .min(dd[1:-1, :-2] + 1) + .min(dd[1:-1, 2:] + 1) + ) + d[...] = (1 - m) * d + m * self.height * self.width + if e.equal(d): + break + + return d + + # @torch.compile + def task_path(self, A, f_A, B, f_B): + nb_rec = 2 + c = torch.randperm(self.nb_colors - 1)[: nb_rec + 2] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + while True: + X[...] = 0 + f_X[...] 
= 0 + + r = self.rec_coo(nb_rec, prevent_overlap=True) + for n in range(nb_rec): + i1, j1, i2, j2 = r[n] + X[i1:i2, j1:j2] = c[n] + f_X[i1:i2, j1:j2] = c[n] + + i1, i2 = torch.randint(self.height, (2,)) + j1, j2 = torch.randint(self.width, (2,)) + if ( + abs(i1 - i2) + abs(j1 - j2) > 2 + and X[i1, j1] == 0 + and X[i2, j2] == 0 + ): + d2 = self.compdist(X, i2, j2) + d = self.compdist(X, i1, j1) + + if d2[i1, j1] < 2 * self.width: + break + + m = ((d + d2) == d[i2, j2]).long() + f_X[...] = m * c[-1] + (1 - m) * f_X + + X[i1, j1] = c[-2] + X[i2, j2] = c[-2] + f_X[i1, j1] = c[-2] + f_X[i2, j2] = c[-2] + + # @torch.compile + def task_fill(self, A, f_A, B, f_B): + nb_rec = 3 + c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + accept_full = torch.rand(1) < 0.5 + + while True: + X[...] = 0 + f_X[...] = 0 + + r = self.rec_coo(nb_rec, prevent_overlap=True) + for n in range(nb_rec): + i1, j1, i2, j2 = r[n] + X[i1:i2, j1:j2] = c[n] + f_X[i1:i2, j1:j2] = c[n] + + while True: + i, j = ( + torch.randint(self.height, (1,)).item(), + torch.randint(self.width, (1,)).item(), + ) + if X[i, j] == 0: + break + + d = self.compdist(X, i, j) + m = (d < self.height * self.width).long() + X[i, j] = c[-1] + f_X[...] 
= m * c[-1] + (1 - m) * f_X + f_X[i, j] = 0 + + if accept_full or (d * (X == 0)).max() == self.height * self.width: + break + + def TOO_HARD_task_addition(self, A, f_A, B, f_B): + c = torch.randperm(self.nb_colors - 1)[:4] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + N1 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item() + N2 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item() + S = N1 + N2 + for j in range(self.width): + r1 = (N1 // (2**j)) % 2 + X[0, -j - 1] = c[r1] + f_X[0, -j - 1] = c[r1] + r2 = (N2 // (2**j)) % 2 + X[1, -j - 1] = c[r2] + f_X[1, -j - 1] = c[r2] + rs = (S // (2**j)) % 2 + f_X[2, -j - 1] = c[2 + rs] + + def task_science_implicit(self, A, f_A, B, f_B): + nb_rec = 5 + c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1 + + for X, f_X in [(A, f_A), (B, f_B)]: + while True: + i1, i2 = torch.randint(self.height, (2,)).sort().values + if i1 >= 1 and i2 < self.height and i1 + 3 < i2: + break + + while True: + j1, j2 = torch.randint(self.width, (2,)).sort().values + if j1 >= 1 and j2 < self.width and j1 + 3 < j2: + break + + f_X[i1:i2, j1:j2] = c[0] + + # --------------------- - def trivial_prompts_and_answers(self, prompts, answers): + while True: + ii1, ii2 = torch.randint(self.height, (2,)).sort().values + if ii1 >= i1 and ii2 <= i2 and ii1 + 1 < ii2: + break + jj = torch.randint(j1, (1,)) + X[ii1:ii2, jj:j1] = c[1] + f_X[ii1:ii2, jj:j1] = c[1] + + while True: + ii1, ii2 = torch.randint(self.height, (2,)).sort().values + if ii1 >= i1 and ii2 <= i2 and ii1 + 1 < ii2: + break + jj = torch.randint(self.width - j2, (1,)) + j2 + 1 + X[ii1:ii2, j2:jj] = c[2] + f_X[ii1:ii2, j2:jj] = c[2] + + # --------------------- + + while True: + jj1, jj2 = torch.randint(self.width, (2,)).sort().values + if jj1 >= j1 and jj2 <= j2 and jj1 + 1 < jj2: + break + ii = torch.randint(i1, (1,)) + X[ii:i1, jj1:jj2] = c[3] + f_X[ii:i1, jj1:jj2] = c[3] + + while True: + jj1, jj2 = torch.randint(self.width, (2,)).sort().values + if jj1 >= j1 and jj2 <= j2 and jj1 + 1 < jj2: + 
break + ii = torch.randint(self.height - i2, (1,)) + i2 + 1 + X[i2:ii, jj1:jj2] = c[4] + f_X[i2:ii, jj1:jj2] = c[4] + + def task_science_dot(self, A, f_A, B, f_B): + nb_rec = 3 + c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + while True: + X[...] = 0 + f_X[...] = 0 + r = self.rec_coo(nb_rec, prevent_overlap=True) + i, j = ( + torch.randint(self.height, (1,)).item(), + torch.randint(self.width, (1,)).item(), + ) + q = 0 + for n in range(nb_rec): + i1, j1, i2, j2 = r[n] + X[i1:i2, j1:j2] = c[n] + f_X[i1:i2, j1:j2] = c[n] + if i >= i1 and i < i2: + q += 1 + f_X[i, j1:j2] = c[-1] + if j >= j1 and j < j2: + q += 1 + f_X[i1:i2, j] = c[-1] + X[i, j] = c[-1] + f_X[i, j] = c[-1] + if q >= 2: + break + + def collide(self, s, r, rs): + i, j = r + for i2, j2 in rs: + if abs(i - i2) < s and abs(j - j2) < s: + return True + return False + + def task_science_tag(self, A, f_A, B, f_B): + c = torch.randperm(self.nb_colors - 1)[:4] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + rs = [] + while len(rs) < 4: + i, j = ( + torch.randint(self.height - 3, (1,)).item(), + torch.randint(self.width - 3, (1,)).item(), + ) + if not self.collide(s=3, r=(i, j), rs=rs): + rs.append((i, j)) + + for k in range(len(rs)): + i, j = rs[k] + q = min(k, 2) + X[i, j : j + 3] = c[q] + X[i + 2, j : j + 3] = c[q] + X[i : i + 3, j] = c[q] + X[i : i + 3, j + 2] = c[q] + + f_X[i, j : j + 3] = c[q] + f_X[i + 2, j : j + 3] = c[q] + f_X[i : i + 3, j] = c[q] + f_X[i : i + 3, j + 2] = c[q] + if q == 2: + f_X[i + 1, j + 1] = c[-1] + + # end_tasks + + ###################################################################### + + def create_empty_quizzes(self, nb, quad_order=("A", "f_A", "B", "f_B")): S = self.height * self.width - Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S] - f_Bs = answers - return (Bs == f_Bs).long().min(dim=-1).values > 0 + quizzes = torch.zeros(nb, 4 * (S + 1), dtype=torch.int64) + quizzes[:, 0 * (S + 1)] = self.l2tok[quad_order[0]] + quizzes[:, 1 * (S 
+ 1)] = self.l2tok[quad_order[1]] + quizzes[:, 2 * (S + 1)] = self.l2tok[quad_order[2]] + quizzes[:, 3 * (S + 1)] = self.l2tok[quad_order[3]] - def generate_prompts_and_answers( - self, nb, tasks=None, progress_bar=False, device="cpu" - ): - if tasks is None: - tasks = self.all_tasks() + return quizzes + def generate_w_quizzes_(self, nb, tasks=None, progress_bar=False): S = self.height * self.width - prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64) - answers = torch.zeros(nb, S, dtype=torch.int64) - bunch = zip(prompts, answers) + if tasks is None: + tasks = self.all_tasks + + quizzes = torch.empty(nb, 4 * self.height * self.width, dtype=torch.int64) if progress_bar: - bunch = tqdm.tqdm( - bunch, + quizzes = tqdm.tqdm( + quizzes, dynamic_ncols=True, - desc="world generation", - total=prompts.size(0), + desc="world quizzes generation", + total=quizzes.size(0), ) - for prompt, answer in bunch: - A = prompt[0 * (S + 1) : 0 * (S + 1) + S].view(self.height, self.width) - f_A = prompt[1 * (S + 1) : 1 * (S + 1) + S].view(self.height, self.width) - B = prompt[2 * (S + 1) : 2 * (S + 1) + S].view(self.height, self.width) - f_B = answer.view(self.height, self.width) - task = tasks[torch.randint(len(tasks), (1,))] + for quiz in quizzes: + q = quiz.reshape(4, self.height, self.width) + q[...] 
= 0 + A, f_A, B, f_B = q + task = tasks[torch.randint(len(tasks), (1,)).item()] task(A, f_A, B, f_B) - return prompts.flatten(1), answers.flatten(1) + return quizzes - def save_quizzes( - self, - result_dir, - filename_prefix, - prompts, - answers, - predicted_prompts=None, - predicted_answers=None, - nrow=4, - ): - self.save_image( - result_dir, - filename_prefix + ".png", - prompts, - answers, - predicted_prompts, - predicted_answers, - nrow, - ) + def save_some_examples(self, result_dir, prefix=""): + nb, nrow = 256, 8 + for t in self.all_tasks: + print(t.__name__) + quizzes = self.generate_w_quizzes_(nb, tasks=[t]) + self.save_quizzes_as_image( + result_dir, prefix + t.__name__ + ".png", quizzes, nrow=nrow + ) ###################################################################### @@ -803,37 +1829,121 @@ class Grids(problem.Problem): if __name__ == "__main__": import time + # grids = Grids(max_nb_cached_chunks=5, chunk_size=100, nb_threads=4) + grids = Grids() - # nb = 1000 - # grids = problem.MultiThreadProblem( - # grids, max_nb_cached_chunks=50, chunk_size=100, nb_threads=1 - # ) - # time.sleep(10) - # start_time = time.perf_counter() - # prompts, answers = grids.generate_prompts_and_answers(nb) - # delay = time.perf_counter() - start_time - # print(f"{prompts.size(0)/delay:02f} seq/s") - # exit(0) - - if True: - nb = 72 - - for t in grids.all_tasks(): - # for t in [grids.task_ortho]: - print(t.__name__) - prompts, answers = grids.generate_prompts_and_answers(nb, tasks=[t]) - grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4) + nb, nrow = 64, 4 + # nb, nrow = 8, 2 + + # for t in grids.all_tasks: + + for t in [ + grids.task_replace_color, + grids.task_translate, + grids.task_grow, + grids.task_frame, + ]: + print(t.__name__) + w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t]) + + # w_quizzes[:5] = torch.randint(grids.vocabulary_size(), w_quizzes[:5].size()) + + grids.save_quizzes_as_image( + "/tmp", + t.__name__ + ".png", + 
w_quizzes, + delta=True, + # grids=False + # comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))], + ) + + exit(0) - exit(0) + q = grids.text2quiz( + """ + +# the original + +vvvvaaaaa. rrrraaaaa. .......... .......... +vvvvaaaaa. rrrraaaaa. ...aaa.... ...aaa.... +vvvvaaaaa. rrrraaaaa. ...aaa.... ...aaa.... +vvvvaaaaa. rrrraaaaa. ...aaa.... ...aaa.... +....aaaaa. ....aaaaa. .vvvvv.... .rrrrr.... +.......... .......... .vvvvvvvvv .rrrrroooo +.......... .......... .vvvvvvvvv .rrrrroooo +....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo +....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo +....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo + +vvvvaaaaa. rrrraaaaa. .......... .......... +vvvvaaaaa. rrrraaaaa. .......... .......... +vvvvaaaaa. rrrraaaaa. .......aaa .......aaa +vvvvaaaaa. rrrraaaaa. .......aaa .......aaa +....aaaaa. ....aaaaa. .vvvvv.aaa .rrrrr.aaa +.......... .......... .vvvvvvvvv .rrrrroooo +.......... .......... .vvvvvvvvv .rrrrroooo +....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo +....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo +....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo + +# +# so what +# + +vvvv...... rrrr...... .......... .......... +vvvv...... rrrr...... .......... .......... +vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa +vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa +.....aaaaa .....aaaaa .vvvvv.aaa .rrrrr.aaa +.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo +.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo +....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo +....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo +....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo + +vvvv...... rrrr...... .......... .......... +vvvv...... rrrr...... .......... .......... +vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa +vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa +.....aaaaa .....aaaaa .vvvvv.aaa .rrrrr.aaa +.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo +.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo +....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo +....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo +....vvvvv. ....ooooo. 
.vvvvvvvvv .rrrrroooo +""" + ) + + grids.save_quizzes_as_image("/tmp", "test.png", q, nrow=1, grids=False) - nb = 500 + exit(0) - for t in grids.all_tasks(): + nb = 1000 + + for t in [ + # grids.task_bounce, + # grids.task_contact, + # grids.task_corners, + # grids.task_detect, + # grids.task_fill, + # grids.task_frame, + # grids.task_grow, + # grids.task_half_fill, + # grids.task_isometry, + # grids.task_path, + # grids.task_replace_color, + # grids.task_scale, + grids.task_symbols, + # grids.task_trajectory, + # grids.task_translate, + ]: + # for t in [grids.task_path]: start_time = time.perf_counter() - prompts, answers = grids.generate_prompts_and_answers(nb, tasks=[t]) + w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t]) delay = time.perf_counter() - start_time - print(f"{t.__name__} {prompts.size(0)/delay:02f} seq/s") + print(f"{t.__name__} {w_quizzes.size(0)/delay:02f} seq/s") + grids.save_quizzes_as_image("/tmp", t.__name__ + ".png", w_quizzes[:128]) exit(0) @@ -841,9 +1951,9 @@ if __name__ == "__main__": predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1) predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1) - grids.save_quizzes( + grids.save_quizzes_as_image( "/tmp", - "test", + "test.png", prompts[:nb], answers[:nb], # You can add a bool to put a frame around the predicted parts diff --git a/main.py b/main.py index 9c3d7f1..5dceefc 100755 --- a/main.py +++ b/main.py @@ -5,32 +5,25 @@ # Written by Francois Fleuret -import math, sys, argparse, time, tqdm, os, datetime, warnings +import math, sys, argparse, time, tqdm, os, datetime, warnings, copy import torch, torchvision from torch import nn from torch.nn import functional as F -import ffutils +import ffutils, grids, attae -import mygpt -import sky, grids, quiz_machine -from problem import MultiThreadProblem +import threading, subprocess -# world quizzes vs. 
culture quizzes +# import torch.multiprocessing as mp -###################################################################### +torch.set_float32_matmul_precision("high") -if torch.cuda.is_available(): - device = torch.device("cuda") - torch.backends.cuda.matmul.allow_tf32 = True -else: - device = torch.device("cpu") +# torch.set_default_dtype(torch.bfloat16) ###################################################################### parser = argparse.ArgumentParser( - description="An implementation of GPT with cache.", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) @@ -40,25 +33,39 @@ parser.add_argument("--result_dir", type=str, default=None) parser.add_argument("--seed", type=int, default=0) -parser.add_argument("--max_percents_of_test_in_train", type=int, default=1) +parser.add_argument("--resume", action="store_true", default=False) -######################################## +# ---------------------------------- parser.add_argument("--nb_epochs", type=int, default=10000) -parser.add_argument("--batch_size", type=int, default=None) +parser.add_argument("--batch_size", type=int, default=25) + +parser.add_argument("--train_batch_size", type=int, default=None) -parser.add_argument("--physical_batch_size", type=int, default=None) +parser.add_argument("--eval_batch_size", type=int, default=25) -parser.add_argument("--nb_train_samples", type=int, default=None) +parser.add_argument("--nb_train_samples", type=int, default=50000) -parser.add_argument("--nb_test_samples", type=int, default=None) +parser.add_argument("--nb_test_samples", type=int, default=2500) + +parser.add_argument("--nb_c_quizzes", type=int, default=5000) + +parser.add_argument("--c_quiz_multiplier", type=int, default=1) parser.add_argument("--learning_rate", type=float, default=5e-4) -######################################## +parser.add_argument("--nb_have_to_be_correct", type=int, default=3) + +parser.add_argument("--nb_have_to_be_wrong", type=int, default=1) -parser.add_argument("--model", 
type=str, default=None) +parser.add_argument("--nb_mistakes_to_be_wrong", type=int, default=5) + +# ---------------------------------- + +parser.add_argument("--model_type", type=str, default="standard") + +parser.add_argument("--model", type=str, default="37M") parser.add_argument("--dim_model", type=int, default=None) @@ -70,72 +77,52 @@ parser.add_argument("--nb_heads", type=int, default=None) parser.add_argument("--nb_blocks", type=int, default=None) -parser.add_argument("--dropout", type=float, default=0.1) +parser.add_argument("--dropout", type=float, default=0.5) -######################################## +# ---------------------------------- -parser.add_argument("--deterministic_synthesis", action="store_true", default=False) +parser.add_argument("--nb_threads", type=int, default=1) -parser.add_argument("--problem", type=str, default="grids") +parser.add_argument("--gpus", type=str, default="all") -parser.add_argument("--multi_thread_problem", action="store_true", default=False) +# ---------------------------------- -parser.add_argument("--nb_gpts", type=int, default=5) +parser.add_argument("--nb_models", type=int, default=5) -parser.add_argument("--min_to_validate", type=int, default=None) +parser.add_argument("--diffusion_nb_iterations", type=int, default=25) -parser.add_argument("--max_to_validate", type=int, default=None) +parser.add_argument("--diffusion_proba_corruption", type=float, default=0.05) -parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975) +parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95) -parser.add_argument("--generation_temperature", type=float, default=2.0) +parser.add_argument("--proba_prompt_noise", type=float, default=0.05) -parser.add_argument("--deterministic_validation", action="store_true", default=False) +parser.add_argument("--proba_hint", type=float, default=0.25) -parser.add_argument("--bidirectional_validation", action="store_true", default=False) - 
-parser.add_argument("--dirty_debug", action="store_true", default=False) +parser.add_argument("--quizzes", type=str, default=None) ###################################################################### -parser.add_argument("--sky_height", type=int, default=6) - -parser.add_argument("--sky_width", type=int, default=8) - -parser.add_argument("--sky_nb_birds", type=int, default=3) - -parser.add_argument("--sky_nb_iterations", type=int, default=2) +grids_tasks = ", ".join( + [x.__name__.removeprefix("task_") for x in grids.Grids().all_tasks] +) -parser.add_argument("--sky_speed", type=int, default=3) +parser.add_argument( + "--grids_world_tasks", + type=str, + default="replace_color,translate,grow,frame", + help="A comma-separated subset of: " + grids_tasks + ".", +) ###################################################################### args = parser.parse_args() -if args.min_to_validate is None: - args.min_to_validate = args.nb_gpts - 1 - -if args.max_to_validate is None: - args.max_to_validate = args.nb_gpts - 1 - if args.result_dir is None: args.result_dir = f"results_culture" ###################################################################### -default_args = { - "model": "37M", - "batch_size": 100, - "nb_train_samples": 100000, - "nb_test_samples": 10000, -} - -for k, v in default_args.items(): - if getattr(args, k) is None: - setattr(args, k, v) - -###################################################################### - default_model_args = { "17K": { "dim_model": 32, @@ -183,11 +170,16 @@ else: ###################################################################### -try: - os.mkdir(args.result_dir) -except FileExistsError: - print(f"result directory {args.result_dir} already exists") - exit(1) +if args.resume: + if not os.path.isdir(args.result_dir): + print(f"Trying to resume from a non-existing result dir {args.result_dir}.") + exit(1) +else: + try: + os.mkdir(args.result_dir) + except FileExistsError: + print(f"result directory {args.result_dir} already 
exists") + exit(1) log_file = open(os.path.join(args.result_dir, args.log_filename), "a") @@ -203,6 +195,9 @@ if args.seed >= 0: def log_string(s): + """print the given string prefixed with a time stamps, and log it + into log_file is not None""" + t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime()) if log_file is not None: @@ -213,6 +208,18 @@ def log_string(s): sys.stdout.flush() +###################################################################### +# Create a time-stamped archive of the source code + +with open("this_run.sh", "w") as f: + f.write(f"{' '.join(sys.argv)}\n") + +now = time.strftime("%Y%m%d-%H%M%S", time.localtime()) + +os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py *.sh") + +###################################################################### + log_string(f"argv {' '.join(sys.argv)}") for n in vars(args): @@ -221,366 +228,773 @@ for n in vars(args): ###################################################################### -if args.dirty_debug: - args.nb_train_samples = 2500 - args.nb_test_samples = 100 +if args.gpus == "all": + gpus_idx = range(torch.cuda.device_count()) +else: + gpus_idx = [int(k) for k in args.gpus.split(",")] + +gpus = [torch.device(f"cuda:{n}") for n in gpus_idx] + +if torch.cuda.is_available(): + main_device = gpus[0] +else: + assert len(gpus) == 0 + main_device = torch.device("cpu") -if args.physical_batch_size is None: - args.physical_batch_size = args.batch_size +if args.train_batch_size is None: + args.train_batch_size = args.batch_size else: - assert args.batch_size % args.physical_batch_size == 0 + assert args.batch_size % args.train_batch_size == 0 assert args.nb_train_samples % args.batch_size == 0 assert args.nb_test_samples % args.batch_size == 0 -if args.problem == "sky": - problem = sky.Sky( - height=args.sky_height, - width=args.sky_width, - nb_birds=args.sky_nb_birds, - nb_iterations=args.sky_nb_iterations, - speed=args.sky_speed, - ) - back_accuracy = False -elif args.problem == "grids": - 
######################################################################


def optimizer_to(optim, device):
    """Move the tensors of the state of the optimizer optim to device.

    Both the tensors stored directly in optim.state and the ones
    stored one level down in a dictionary are moved, together with
    their gradient when they have one.

    """

    def move(t):
        t.data = t.data.to(device)
        if t._grad is not None:
            t._grad.data = t._grad.data.to(device)

    for param in optim.state.values():
        # Not sure there are any global tensors in the state dict
        if isinstance(param, torch.Tensor):
            move(param)
        elif isinstance(param, dict):
            for subparam in param.values():
                if isinstance(subparam, torch.Tensor):
                    move(subparam)


######################################################################


def generate_quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1):
    """Return a shuffled set of quizzes mixing world quizzes and c_quizzes.

    c_quizzes is either None or a tensor with one quiz per row. When
    c_quiz_multiplier is greater than one the c_quizzes are repeated,
    at most c_quiz_multiplier times, and in any case they are limited
    to nb_samples // 2 rows. The rest is requested from
    problem.generate_w_quizzes.

    """

    # An empty tensor of c_quizzes is handled as no c_quizzes at all,
    # this avoids a division by zero when c_quiz_multiplier > 1
    if c_quizzes is None or c_quizzes.size(0) == 0:
        quizzes = problem.generate_w_quizzes(nb_samples)
        nb_w_quizzes = quizzes.size(0)
        nb_c_quizzes = 0
    else:
        if c_quiz_multiplier > 1:
            n = min(c_quiz_multiplier, (nb_samples // 2) // c_quizzes.size(0))
            body = c_quizzes.repeat(n, 1)
            if n < c_quiz_multiplier:
                # Not enough room for all the repetitions, complete
                # with a random subset
                tail = c_quizzes[
                    torch.randperm(c_quizzes.size(0))[: nb_samples // 2 - body.size(0)]
                ]
                c_quizzes = torch.cat([body, tail], dim=0)
            else:
                c_quizzes = body

        if c_quizzes.size(0) > nb_samples // 2:
            i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2]
            c_quizzes = c_quizzes[i]

        w_quizzes = problem.generate_w_quizzes(nb_samples - c_quizzes.size(0))

        quizzes = torch.cat([w_quizzes, c_quizzes], dim=0)
        nb_w_quizzes = w_quizzes.size(0)
        nb_c_quizzes = c_quizzes.size(0)

    i = torch.randperm(quizzes.size(0), device=quizzes.device)
    quizzes = quizzes[i].contiguous()

    log_string(f"quiz_set nb_w_quizzes {nb_w_quizzes} nb_c_quizzes {nb_c_quizzes}")

    return quizzes


######################################################################


def add_hints_imt(imt_set):
    """Set every component of the mask to zero with probability
    args.proba_hint, and for each component set to zero, copy the
    corresponding value from the target into the input

    imt_set stacks the input, the masks and the targets along
    dimension 1, and so does the returned tensor.

    """
    input, masks, targets = imt_set.unbind(dim=1)
    # h = torch.rand(masks.size(), device=masks.device) - masks
    # t = h.sort(dim=1).values[:, args.nb_hints, None]
    # mask_hints = (h < t).long()
    mask_hints = (
        torch.rand(input.size(), device=input.device) < args.proba_hint
    ).long() * masks
    masks = (1 - mask_hints) * masks
    input = (1 - mask_hints) * input + mask_hints * targets
    return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)


def add_noise_imt(imt_set):
    """Replace every component of the input by a random value with
    probability args.proba_prompt_noise.

    Only the components whose mask is zero can be changed, the random
    values come from problem.pure_noise.

    """
    input, masks, targets = imt_set.unbind(dim=1)
    noise = problem.pure_noise(input.size(0), input.device)
    change = (1 - masks) * (
        torch.rand(input.size(), device=input.device) < args.proba_prompt_noise
    ).long()
    input = (1 - change) * input + change * noise
    return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)


######################################################################
# Prediction


def samples_for_prediction_imt(input):
    """Build an input / masks / targets set to predict one part out of four.

    Every row of input is seen as four parts of the same length, one
    of them is picked uniformly at random, its mask is set to one and
    its content is zeroed in the input. The original rows are the
    targets.

    """
    nb = input.size(0)
    masks = input.new_zeros(input.size())
    u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4)
    masks.view(nb, 4, -1)[...] = u[:, :, None]
    targets = input
    input = (1 - masks) * targets

    return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)


def ae_predict(model, imt_set, local_device=main_device):
    """Sample with model the masked components of the inputs of imt_set.

    The model is put in eval mode and moved to local_device. imt_set
    is not moved, it has to be on local_device already. In the
    returned tensor the components whose mask is zero keep the value
    of the input, and the others are sampled from the distribution
    predicted by the model.

    """
    model.eval().to(local_device)

    record = []

    src = tqdm.tqdm(
        imt_set.split(args.eval_batch_size),
        dynamic_ncols=True,
        desc="predict",
        total=imt_set.size(0) // args.eval_batch_size,
        delay=10,
    )

    # This is pure inference. The grad mode is per-thread, so it is
    # disabled here and not in the caller
    with torch.no_grad():
        for imt in src:
            # some paranoia
            imt = imt.clone()
            imt[:, 0] = imt[:, 0] * (1 - imt[:, 1])

            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = model(imt[:, 0] * 2 + imt[:, 1])
            dist = torch.distributions.categorical.Categorical(logits=logits)
            result = (1 - imt[:, 1]) * imt[:, 0] + imt[:, 1] * dist.sample()
            record.append(result)

    return torch.cat(record)


def predict_the_four_grids(
    model, input, with_noise=False, with_hints=False, local_device=main_device
):
    """Predict each of the four parts of every row given the three others.

    Every row of input is duplicated four times, with a different
    part masked in each copy. Hints and noise are added if requested,
    and the four predicted parts are put back together in a single
    row.

    """
    input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1))
    nb = input.size(0)
    masks = input.new_zeros(input.size())
    # The k-th copy of a row has its part k % 4 masked
    u = F.one_hot(torch.arange(nb, device=masks.device) % 4, num_classes=4)
    masks.view(nb, 4, -1)[...] = u[:, :, None]
    targets = input
    input = (1 - masks) * targets
    imt_set = torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)

    if with_hints:
        imt_set = add_hints_imt(imt_set)

    if with_noise:
        imt_set = add_noise_imt(imt_set)

    result = ae_predict(model, imt_set, local_device=local_device)
    # Keep from each copy only the part that was masked in it
    result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1)

    return result


######################################################################


def samples_for_generation_imt(input):
    """Build an input / masks / targets set of denoising examples.

    For every row a number of steps t is sampled between 1 and
    args.diffusion_nb_iterations, with a probability decreasing
    geometrically by a total factor 0.1, and every component is
    replaced by noise with probability
    1 - (1 - args.diffusion_proba_corruption) ** t. The masks are all
    ones and the targets are the original rows.

    """
    nb = input.size(0)
    probs_iterations = 0.1 ** torch.linspace(
        0, 1, args.diffusion_nb_iterations, device=input.device
    )
    probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
    probs_iterations = probs_iterations.expand(nb, -1)
    dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
    t = dist.sample() + 1
    r = torch.rand(input.size(), device=input.device)
    proba_erased = 1 - (1 - args.diffusion_proba_corruption) ** t
    mask_erased = (r <= proba_erased[:, None]).long()

    noise = problem.pure_noise(nb, input.device)
    targets = input
    input = (1 - mask_erased) * input + mask_erased * noise
    masks = input.new_full(input.size(), 1)

    return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)


def prioritized_rand(low):
    """Return a tensor of the size of low with uniform random values,
    arranged so that in every row the values at the components where
    low is true are smaller than the ones where it is false.

    """
    x = torch.rand(low.size(), device=low.device).sort(dim=1, descending=True).values
    # Random order in which the components where low is false come first
    k = torch.rand(low.size(), device=low.device) + low.long()
    k = k.sort(dim=1).indices
    # k is a permutation of every row, so all of y gets written
    y = torch.empty_like(x)
    y.scatter_(dim=1, index=k, src=x)
    return y


def ae_generate(model, nb, local_device=main_device):
    """Generate nb samples with model, starting from pure noise.

    At every iteration new values are sampled from the model and some
    of the components are updated with them, in priority the ones
    that differ from the current value. It stops after
    args.diffusion_nb_iterations iterations, or as soon as no sample
    changed.

    """
    model.eval().to(local_device)

    # We loop through the iterations first and through the
    # mini-batches second so that we keep only the samples that have
    # not stabilized

    all_input = problem.pure_noise(nb, local_device)
    all_masks = all_input.new_full(all_input.size(), 1)
    all_changed = torch.full((all_input.size(0),), True, device=all_input.device)

    # This is pure inference. The grad mode is per-thread, so it is
    # disabled here and not in the caller
    with torch.no_grad():
        for _ in range(args.diffusion_nb_iterations):
            # log_string(f"nb_changed {all_changed.long().sum().item()}")

            if not all_changed.any():
                break

            sub_input = all_input[all_changed].clone()
            sub_masks = all_masks[all_changed].clone()
            sub_changed = all_changed[all_changed].clone()

            src = zip(
                sub_input.split(args.eval_batch_size),
                sub_masks.split(args.eval_batch_size),
                sub_changed.split(args.eval_batch_size),
            )

            # input and changed are views, so updating them in place
            # updates sub_input and sub_changed
            for input, masks, changed in src:
                with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                    logits = model(input * 2 + masks)
                dist = torch.distributions.categorical.Categorical(logits=logits)
                output = dist.sample()
                r = prioritized_rand(input != output)
                mask_changes = (r <= args.diffusion_proba_corruption).long() * masks
                update = (1 - mask_changes) * input + mask_changes * output
                changed[...] = changed & (update != input).max(dim=1).values
                input[...] = update

            a = all_changed.clone()
            all_input[a] = sub_input
            all_masks[a] = sub_masks
            all_changed[a] = sub_changed

    return all_input


######################################################################


def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device):
    """Run one pass of model over a fresh quiz set and log the mean loss.

    With train=True the model is trained with model.optimizer, the
    gradient being accumulated until args.batch_size samples have
    been processed. With train=False the loss is only measured, and
    no gradient is computed.

    """
    quizzes = generate_quiz_set(
        args.nb_train_samples if train else args.nb_test_samples,
        c_quizzes,
        args.c_quiz_multiplier,
    )

    q_p, q_g = quizzes.to(local_device).chunk(2)

    # Half of the samples train the prediction, and we inject noise in
    # all, and hints in half
    b_p = samples_for_prediction_imt(q_p)
    b_p = add_noise_imt(b_p)
    half = torch.rand(b_p.size(0)) < 0.5
    b_p[half] = add_hints_imt(b_p[half])

    # The other half are denoising examples for the generation
    b_g = samples_for_generation_imt(q_g)

    imt_set = torch.cat([b_p, b_g])
    imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)]

    if train:
        label = "train"
        model.train().to(local_device)
        optimizer_to(model.optimizer, local_device)
        batch_size = args.train_batch_size
    else:
        label = "test"
        model.eval().to(local_device)
        batch_size = args.eval_batch_size

    nb_samples, acc_loss = 0, 0.0

    # No graph is needed for the test pass
    with torch.set_grad_enabled(train):
        for imt in tqdm.tqdm(
            imt_set.split(batch_size),
            dynamic_ncols=True,
            desc=label,
            total=quizzes.size(0) // batch_size,
            delay=10,
        ):
            input, masks, targets = imt.unbind(dim=1)
            if train and nb_samples % args.batch_size == 0:
                model.optimizer.zero_grad()

            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = model(input * 2 + masks)

            loss_per_token = F.cross_entropy(
                logits.transpose(1, 2), targets, reduction="none"
            )
            # Only the masked components contribute to the loss
            loss = (loss_per_token * masks).mean()
            acc_loss += loss.item() * imt.size(0)
            nb_samples += imt.size(0)

            if train:
                loss.backward()

                if nb_samples % args.batch_size == 0:
                    model.optimizer.step()

    log_string(f"{label}_loss {n_epoch} model {model.id} {acc_loss/nb_samples}")
######################################################################


def save_inference_images(model, n_epoch, c_quizzes, c_quiz_multiplier, local_device):
    """Save an image of predictions and an image of generations of model.

    The file names contain n_epoch and model.id, and the images are
    written in args.result_dir.

    """

    # Save some images of the prediction results

    quizzes = generate_quiz_set(150, c_quizzes, c_quiz_multiplier)
    imt_set = samples_for_prediction_imt(quizzes.to(local_device))
    result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
    masks = imt_set[:, 1].to("cpu")

    correct = (quizzes == result).min(dim=1).values.long()
    # For every quiz, +1 / -1 on the part that was predicted if the
    # prediction is correct / wrong, and 0 on the three other parts
    correct_parts = (2 * correct - 1)[:, None] * masks.reshape(masks.size(0), 4, -1)[
        :, :, 1
    ]
    predicted_parts = correct_parts.abs()

    problem.save_quizzes_as_image(
        args.result_dir,
        f"culture_prediction_{n_epoch}_{model.id}.png",
        quizzes=result[:128],
        predicted_parts=predicted_parts[:128],
        correct_parts=correct_parts[:128],
    )

    # Save some images of the ex nihilo generation of the four grids

    result = ae_generate(model, 150, local_device=local_device).to("cpu")
    problem.save_quizzes_as_image(
        args.result_dir,
        f"culture_generation_{n_epoch}_{model.id}.png",
        quizzes=result[:128],
    )


######################################################################


def one_complete_epoch(
    model, n_epoch, train_c_quizzes, test_c_quizzes, local_device=main_device
):
    """Train model for one epoch, test it, and update model.test_accuracy.

    The train pass uses train_c_quizzes, the test pass, the accuracy
    and the saved images use test_c_quizzes.

    """
    one_epoch(model, n_epoch, train_c_quizzes, train=True, local_device=local_device)

    one_epoch(model, n_epoch, test_c_quizzes, train=False, local_device=local_device)

    # Compute the test accuracy

    quizzes = generate_quiz_set(
        args.nb_test_samples, test_c_quizzes, args.c_quiz_multiplier
    )
    imt_set = samples_for_prediction_imt(quizzes.to(local_device))
    result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
    # A quiz is correct if all its components are
    correct = (quizzes == result).min(dim=1).values.long()

    nb_correct, nb_total = correct.sum().item(), quizzes.size(0)
    model.test_accuracy = nb_correct / nb_total

    log_string(
        f"test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({model.test_accuracy*100:.02f}%)"
    )

    save_inference_images(
        model,
        n_epoch,
        test_c_quizzes,
        args.c_quiz_multiplier,
        local_device=local_device,
    )


######################################################################


def max_nb_mistakes_on_one_grid(quizzes, prediction):
    """Return for every quiz the number of wrong components in the
    part, out of four, that has the most of them."""
    return (
        (prediction != quizzes)
        .long()
        .reshape(quizzes.size(0), 4, -1)
        .sum(dim=2)
        .max(dim=1)
        .values
    )


def evaluate_quizzes(quizzes, models, with_hints, local_device):
    """Count for every quiz how many models get it right or wrong.

    A model is correct on a quiz if it predicts the four parts
    without any mistake, and it is wrong if it makes at least
    args.nb_mistakes_to_be_wrong mistakes on one of them. Returns the
    pair nb_correct, nb_wrong, with one value per quiz, on
    local_device.

    """
    # ae_predict does not move its input, so this has to be done here
    quizzes = quizzes.to(local_device)

    nb_correct, nb_wrong = 0, 0

    for model in models:
        # Work on a copy so that the original model stays where it is
        model = copy.deepcopy(model).to(local_device).eval()
        predicted = predict_the_four_grids(
            model=model,
            input=quizzes,
            with_noise=False,
            with_hints=with_hints,
            local_device=local_device,
        )
        nb_mistakes = max_nb_mistakes_on_one_grid(quizzes, predicted)
        nb_correct += (nb_mistakes == 0).long()
        nb_wrong += (nb_mistakes >= args.nb_mistakes_to_be_wrong).long()

    return nb_correct, nb_wrong


######################################################################


def identity_quizzes(quizzes):
    """Return a Boolean per quiz, true if its two first parts are
    identical or if its two last parts are."""
    quizzes = quizzes.reshape(quizzes.size(0), 4, -1)
    return (quizzes[:, 0] == quizzes[:, 1]).min(dim=1).values | (
        quizzes[:, 2] == quizzes[:, 3]
    ).min(dim=1).values


def generate_c_quizzes(models, nb_to_generate, local_device=main_device):
    """Generate at least nb_to_generate validated c_quizzes, returned on cpu.

    Quizzes are generated by models picked at random, and one is
    validated if at least args.nb_have_to_be_correct models are
    correct on it and at least args.nb_have_to_be_wrong are wrong, as
    computed by evaluate_quizzes with hints.

    """
    record = []
    nb_validated = 0

    start_time = time.perf_counter()
    last_log = -1

    while nb_validated < nb_to_generate:
        # Generate new quizzes

        model = models[torch.randint(len(models), (1,)).item()]
        model = copy.deepcopy(model).to(local_device).eval()
        generator_id = model.id

        c_quizzes = ae_generate(
            model=model, nb=args.eval_batch_size * 10, local_device=local_device
        )

        c_quizzes = c_quizzes[~identity_quizzes(c_quizzes)]

        if c_quizzes.size(0) > 0:
            # Select the ones that are solved properly by some models and
            # not understood by others

            nb_correct, nb_wrong = evaluate_quizzes(
                quizzes=c_quizzes,
                models=models,
                with_hints=True,
                local_device=local_device,
            )

            to_keep = (nb_correct >= args.nb_have_to_be_correct) & (
                nb_wrong >= args.nb_have_to_be_wrong
            )

            nb_validated += to_keep.long().sum().item()
            record.append(c_quizzes[to_keep])

        #####################

        duration = time.perf_counter() - start_time

        # Log the progress at most every ten seconds
        if last_log < 0 or duration > last_log + 10:
            last_log = duration
            if nb_validated > 0:
                if nb_validated < nb_to_generate:
                    d = (nb_to_generate - nb_validated) * duration / nb_validated
                    e = (
                        datetime.datetime.now() + datetime.timedelta(seconds=d)
                    ).strftime("%a %H:%M")
                else:
                    e = "now!"
            else:
                e = "???"

            log_string(
                f"nb_validated {nb_validated} model {generator_id} (finishes {e} -- {int((nb_validated * 3600)/duration)}/h)"
            )

    #####################

    duration = time.perf_counter() - start_time

    log_string(f"generate_c_quizz_speed {int(3600 * nb_validated / duration)}/h")

    return torch.cat(record).to("cpu")


######################################################################


def multithread_execution(fun, arguments):
    """Call fun on every tuple of arguments, one thread per tuple.

    With a single tuple fun is called directly, without a thread. A
    result that is not a tuple is seen as a tuple with one element.
    Returns None if the first recorded result is None, and otherwise
    the list of the concatenations along the first dimension of the
    k-th elements of all the results, in the order the calls finished.

    """

    def as_tuple(r):
        return r if isinstance(r, tuple) else (r,)

    if len(arguments) == 1:
        # Single instance, no thread. The result goes through the
        # same post-processing as with threads, so that the caller
        # gets the same structure whatever the number of instances
        records = [as_tuple(fun(*(arguments[0])))]
    else:
        records, threads = [], []

        def threadable_fun(*a):
            records.append(as_tuple(fun(*a)))

        for a in arguments:
            # To get a different sequence between threads
            log_string(f"dummy_rand {torch.rand(1)}")
            t = threading.Thread(target=threadable_fun, daemon=True, args=a)
            threads.append(t)
            t.start()

        for t in threads:
            t.join()

    if len(records[0]) == 1 and records[0][0] is None:
        return None

    return [torch.cat([x[k] for x in records], dim=0) for k in range(len(records[0]))]


######################################################################


def save_models(models, suffix=""):
    """Save the state, the optimizer state and the test accuracy of
    every model in args.result_dir, in ae_<id>.pth, or
    ae_<id>_<suffix>.pth if a suffix is given."""
    if suffix != "":
        suffix = "_" + suffix

    for model in models:
        filename = f"ae_{model.id:03d}{suffix}.pth"
        torch.save(
            {
                "state_dict": model.state_dict(),
                "optimizer_state_dict": model.optimizer.state_dict(),
                "test_accuracy": model.test_accuracy,
            },
            os.path.join(args.result_dir, filename),
        )

    log_string(f"wrote ae_*{suffix}.pth")


######################################################################


def save_quiz_image(models, c_quizzes, filename, local_device=main_device):
    """Save in args.result_dir an image of c_quizzes, each commented
    with the number of models that are correct and wrong on it,
    without hints."""
    c_quizzes = c_quizzes.to(local_device)

    nb_correct, nb_wrong = evaluate_quizzes(
        quizzes=c_quizzes,
        models=models,
        with_hints=False,
        local_device=local_device,
    )

    comments = [f"nb_correct {c} nb_wrong {w}" for c, w in zip(nb_correct, nb_wrong)]

    problem.save_quizzes_as_image(
        args.result_dir,
        filename,
        quizzes=c_quizzes,
        comments=comments,
        delta=True,
        nrow=8,
    )

    log_string(f"wrote {filename}")


######################################################################

problem = grids.Grids(
    max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
    chunk_size=100,
    nb_threads=args.nb_threads,
    tasks=args.grids_world_tasks,
)

if not args.resume:
    problem.save_some_examples(args.result_dir)


log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}")

vocabulary_size = problem.vocabulary_size()

log_string(f"vocabulary_size {vocabulary_size}")

######################################################################

models = []

if args.model_type == "standard":
    model_constructor = attae.AttentionAE
elif args.model_type == "functional":
    model_constructor = attae.FunctionalAttentionAE
else:
    raise ValueError(f"Unknown model type {args.model_type}")


for i in range(args.nb_models):
    # NOTE(review): attae.AttentionAE already sizes its embedding with
    # 2 * vocabulary_size, so the factor 2 here may be redundant. It
    # is kept since it determines the shapes in the saved checkpoints
    # -- to be confirmed
    model = model_constructor(
        vocabulary_size=vocabulary_size * 2,
        dim_model=args.dim_model,
        dim_keys=args.dim_keys,
        dim_hidden=args.dim_hidden,
        nb_heads=args.nb_heads,
        nb_blocks=args.nb_blocks,
        dropout=args.dropout,
    )

    # model = torch.compile(model)

    model.id = i
    model.test_accuracy = 0.0
    model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    models.append(model)

######################################################################

current_epoch = 0

if args.resume:
    # weights_only=False unpickles arbitrary objects, so only resume
    # from a result directory that is trusted
    for model in models:
        filename = f"ae_{model.id:03d}.pth"

        d = torch.load(
            os.path.join(args.result_dir, filename),
            map_location="cpu",
            weights_only=False,
        )
        model.load_state_dict(d["state_dict"])
        model.optimizer.load_state_dict(d["optimizer_state_dict"])
        model.test_accuracy = d["test_accuracy"]
        log_string(f"successfully loaded {filename}")

    filename = "state.pth"
    state = torch.load(
        os.path.join(args.result_dir, filename),
        map_location="cpu",
        weights_only=False,
    )

    log_string(f"successfully loaded {filename}")

    current_epoch = state["current_epoch"]
    train_c_quizzes = state["train_c_quizzes"]
    test_c_quizzes = state["test_c_quizzes"]
######################################################################

nb_parameters = sum(p.numel() for p in models[0].parameters())
log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")


######################################################################

# When resuming, the c_quizzes have been loaded from state.pth and
# must not be reset
if not args.resume:
    train_c_quizzes, test_c_quizzes = None, None

######################################################################

# Main loop. Every epoch saves the state, generates new c_quizzes if
# all the models are accurate enough, and trains the weakest models,
# one per gpu

for n_epoch in range(current_epoch, args.nb_epochs):
    start_time = time.perf_counter()

    state = {
        "current_epoch": n_epoch,
        "train_c_quizzes": train_c_quizzes,
        "test_c_quizzes": test_c_quizzes,
    }

    filename = "state.pth"
    torch.save(state, os.path.join(args.result_dir, filename))
    log_string(f"wrote {filename}")

    log_string(f"--- epoch {n_epoch} ----------------------------------------")

    cta = " ".join([f"{float(m.test_accuracy):.04f}" for m in models])
    log_string(f"current_test_accuracies {cta}")

    # --------------------------------------------------------------------

    lowest_test_accuracy = min([float(m.test_accuracy) for m in models])

    if lowest_test_accuracy >= args.accuracy_to_make_c_quizzes:
        if train_c_quizzes is None:
            save_models(models, "naive")

        nb_gpus = len(gpus)
        nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus

        generated = multithread_execution(
            generate_c_quizzes,
            [(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus],
        )

        # With a single gpu the result may come as the tensor itself
        # instead of a list with one tensor
        new_c_quizzes = generated if torch.is_tensor(generated) else generated[0]

        save_quiz_image(
            models, new_c_quizzes[:256], f"culture_c_quiz_{n_epoch:04d}.png"
        )

        log_string(f"generated_c_quizzes {new_c_quizzes.size()}")

        train_c_quizzes = (
            new_c_quizzes
            if train_c_quizzes is None
            else torch.cat([train_c_quizzes, new_c_quizzes])
        )
        # Keep the most recent ones
        train_c_quizzes = train_c_quizzes[-args.nb_train_samples :]

        # The quizzes are stored on the cpu and have to be on the
        # device of the models to be evaluated
        nb_correct, _ = evaluate_quizzes(
            quizzes=train_c_quizzes.to(main_device),
            models=models,
            with_hints=False,
            local_device=main_device,
        )

        # The mask has to be on the device of the indexed tensor
        test_c_quizzes = train_c_quizzes[
            nb_correct.to(train_c_quizzes.device) >= args.nb_have_to_be_correct
        ]

        # Force all the models to be trained with the new c_quizzes
        for model in models:
            model.test_accuracy = 0

    if train_c_quizzes is None:
        log_string("no_c_quiz")
    else:
        log_string(f"nb_c_quizzes {train_c_quizzes.size(0)}")

    # --------------------------------------------------------------------

    ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
    weakest_models = ranked_models[: len(gpus)]

    log_string(
        f"weakest_accuracies {[model.test_accuracy for model in weakest_models]}"
    )

    multithread_execution(
        one_complete_epoch,
        [
            (model, n_epoch, train_c_quizzes, test_c_quizzes, gpu)
            for model, gpu in zip(weakest_models, gpus)
        ],
    )

    save_models(models)

    # --------------------------------------------------------------------

    duration = time.perf_counter() - start_time
    str_duration = ""
    if duration >= 60:
        str_duration += f"{int(duration)//60}min"
    str_duration += f"{int(duration)%60}s"
    # Assumes that the next epoch will take as long as this one
    str_next = (
        datetime.datetime.now() + datetime.timedelta(seconds=duration)
    ).strftime("%H:%M:%S")
    log_string(f"epoch_duration {str_duration} next_finish {str_next}")
- -# Modules able to process it expect that they will have to process a -# first bracket starting at t=0, followed by a succession of brackets -# that move forward in time, do not overlap, and cover the axis T with -# no holes. -# -# Although it is more general, for a classical prompt-conditioned -# auto-regressive process it will be a first bracket starting at 0 and -# of arbitrary length for the "prompt", followed by brackets of length -# 1 for the successive tokens. -# -# Modules able to process brackets may implement a cache that is -# resetted when the input bracket starts at t=0 - - -class BracketedSequence: - def __init__(self, x, first=None, nb=None): - self.x = x - self.first = 0 if first is None else first - self.nb = x.size(1) if nb is None else nb - - def slice(self): - return self.x[:, self.first : self.first + self.nb] - - def complete(self): - return self.first == 0 and self.nb == self.x.size(1) - - -###################################################################### - - -class CacheWrapper(nn.Module): - def __init__(self, *f): - super().__init__() - self.f = f[0] if len(f) == 1 else nn.Sequential(*f) - - def forward(self, bs): - if bs.first == 0: - y = self.f(bs.slice()) - self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:])) - self.cache_y[:, bs.first : bs.first + bs.nb] = y - else: - self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice()) - - return BracketedSequence(self.cache_y, bs.first, bs.nb) - - -############################## - - -class WithResidual(nn.Module): - def __init__(self, *f): - super().__init__() - self.f = f[0] if len(f) == 1 else nn.Sequential(*f) - - def forward(self, bs): - return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb) - - -############################## - - -class AddPositionalEncoding(nn.Module): - def __init__(self, len_max): - super().__init__() - self.len_max = len_max - - # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D})) - - def forward(self, 
bs): - if bs.first == 0: - t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[ - :, None - ] - j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[ - None, : - ] - k = j % 2 - self.pe = torch.sin( - t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k - ) - self.cache_y = bs.x.new(bs.x.size()) - - self.cache_y[:, bs.first : bs.first + bs.nb] = ( - bs.slice() + self.pe[bs.first : bs.first + bs.nb] - ) - - return BracketedSequence(self.cache_y, bs.first, bs.nb) - - -############################## - - -class QKVAttention(nn.Module): - def __init__( - self, - dim_in, - dim_qk, - dim_v, - nb_heads=1, - causal=False, - attention_dropout=0.0, - ): - super().__init__() - - def randw(*d): - return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) - - self.causal = causal - self.attention_dropout = attention_dropout - self.record_attention = False - - self.w_q = randw(nb_heads, dim_qk, dim_in) - self.w_k = randw(nb_heads, dim_qk, dim_in) - self.w_v = randw(nb_heads, dim_v, dim_in) - self.w_o = randw(dim_v * nb_heads, dim_in) - - def forward(self, bs_q): - x_q = bs_q.x - - assert ( - self.causal or bs_q.complete() - ), "Partial evaluation is only possible for causal models" - - if bs_q.first == 0: - self.cache_k = x_q.new_zeros( - x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1) - ) - self.cache_v = x_q.new_zeros( - x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1) - ) - self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1)) - - q = torch.einsum( - "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_q - ) - - self.cache_k[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum( - "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_k - ) - self.cache_v[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum( - "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_v - ) - - a = torch.einsum( - "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + 
bs_q.nb] - ) / math.sqrt(self.w_q.size(1)) - - if self.causal: - if bs_q.first == 0: - self.cache_attzero = ( - torch.arange(x_q.size(1), device=q.device)[None, None, :, None] - < torch.arange(x_q.size(1), device=q.device)[None, None, None, :] - ) - a = a.masked_fill( - self.cache_attzero[ - :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb - ], - float("-inf"), - ) - - a = a.softmax(dim=3) - - if self.record_attention: - self.a = a - - a = F.dropout(a, self.attention_dropout, self.training) - - y = torch.einsum( - "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs_q.first + bs_q.nb] - ).flatten(2) - - self.cache_y[:, bs_q.first : bs_q.first + bs_q.nb] = y @ self.w_o - - return BracketedSequence(self.cache_y, bs_q.first, bs_q.nb) - - -############################## - - -class NoiseInjector(nn.Module): - def __init__(self): - super().__init__() - self.noise_std = 0.0 - - def forward(self, x): - if self.noise_std > 0: - x = x + torch.randn(x.size(), device=x.device) * self.noise_std - return x - - -def set_noise_injection(model, noise_std): - for m in model.modules(): - if isinstance(m, NoiseInjector): - m.noise_std = noise_std - - -############################## - - -class MyGPT(nn.Module): - def __init__( - self, - vocabulary_size, - dim_model, - dim_keys, - dim_hidden, - nb_heads, - nb_blocks, - causal=False, - dropout=0.0, - len_max=1e5, - ): - super().__init__() - - assert dim_model % nb_heads == 0 - - self.embedding = nn.Sequential( - CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)), - AddPositionalEncoding(len_max), - ) - - trunk_blocks = [] - - for b in range(nb_blocks): - trunk_blocks += [ - WithResidual( - CacheWrapper( - nn.LayerNorm((dim_model,)), - NoiseInjector(), - ), - QKVAttention( - dim_in=dim_model, - dim_qk=dim_keys, - dim_v=dim_model // nb_heads, - nb_heads=nb_heads, - causal=causal, - attention_dropout=dropout, - ), - ), - WithResidual( - CacheWrapper( - nn.LayerNorm((dim_model,)), - NoiseInjector(), - 
nn.Linear(in_features=dim_model, out_features=dim_hidden), - nn.ReLU(), - nn.Linear(in_features=dim_hidden, out_features=dim_model), - nn.Dropout(dropout), - ), - ), - ] - - self.trunk = nn.Sequential(*trunk_blocks) - - self.readout = CacheWrapper( - nn.Linear(in_features=dim_model, out_features=vocabulary_size) - ) - - with torch.no_grad(): - for m in self.modules(): - if isinstance(m, nn.Embedding): - m.weight.normal_(mean=0, std=2e-2) - elif isinstance(m, nn.LayerNorm): - m.bias.zero_() - m.weight.fill_(1.0) - - def forward(self, bs): - # print(f"GENERATE {bs.first} {bs.first+bs.nb}") - bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb) - bs = self.embedding(bs) - bs = self.trunk(bs) - bs = self.readout(bs) - return bs - - def record_attention(self, v=True): - for m in self.modules(): - if isinstance(m, QKVAttention): - m.record_attention = v - - def retrieve_attention(self): - a = [] - for m in self.modules(): - if isinstance(m, QKVAttention): - a.append(m.a) - return a - - -###################################################################### - -if __name__ == "__main__": - print("Basic check.") - - vocabulary_size = 3 - x = torch.randint(vocabulary_size, (1, 5)) - - model = MyGPT( - vocabulary_size=vocabulary_size, - dim_model=4, - dim_keys=2, - dim_hidden=2, - nb_heads=2, - nb_blocks=2, - dropout=0.1, - causal=True, - ) - - model.eval() - y1 = model(BracketedSequence(x)).x - y2 = torch.randn_like(y1) - for s in range(x.size(1)): - z = model(BracketedSequence(x, s, 1)) - y2[:, s] = z.slice() - - print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}") - -###################################################################### diff --git a/problem.py b/problem.py index a49634d..8c1db63 100755 --- a/problem.py +++ b/problem.py @@ -9,99 +9,88 @@ import threading, queue, torch, tqdm class Problem: - def nb_token_values(self): - pass - - def trivial_prompts_and_answers(self, prompts, answers): - pass - - # returns two tensors nb x D and nb x 
D' - def generate_prompts_and_answers(self, nb): - pass - - # save a file to vizualize quizzes, you can save a txt or png file - def save_quizzes( - self, - result_dir, - filename_prefix, - prompts, - answers, - predicted_prompts=None, - predicted_answers=None, - ): - pass - - -class MultiThreadProblem: - def __init__(self, problem, max_nb_cached_chunks, chunk_size, nb_threads=1): - self.problem = problem - self.chunk_size = chunk_size - self.queue = queue.Queue(maxsize=max_nb_cached_chunks) - for _ in range(nb_threads): - threading.Thread(target=self.fill_cache, daemon=True).start() - self.rest = None - - def nb_token_values(self): - return self.problem.nb_token_values() + def __init__(self, max_nb_cached_chunks=None, chunk_size=None, nb_threads=-1): + if nb_threads > 0: + self.chunk_size = chunk_size + self.queue = queue.Queue(maxsize=max_nb_cached_chunks) + for _ in range(nb_threads): + threading.Thread(target=self.fill_cache, daemon=True).start() + self.rest = None + else: + self.queue = None - def save_quizzes( - self, - result_dir, - filename_prefix, - prompts, - answers, - predicted_prompts=None, - predicted_answers=None, - ): - self.problem.save_quizzes( - result_dir, - filename_prefix, - prompts, - answers, - predicted_prompts=None, - predicted_answers=None, - ) + def nb_cached_quizzes(self): + if self.queue is None: + return None + else: + return self.queue.qsize() * self.chunk_size def fill_cache(self): while True: - prompts, answers = self.problem.generate_prompts_and_answers( - self.chunk_size - ) + quizzes = self.generate_w_quizzes_(self.chunk_size) + self.queue.put(quizzes.to("cpu"), block=True) - self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True) + def generate_w_quizzes(self, nb, progress_bar=True): + if self.queue is None: + return self.generate_w_quizzes_(nb) - def trivial_prompts_and_answers(self, prompts, answers): - return self.problem.trivial_prompts_and_answers(prompts, answers) - - def generate_prompts_and_answers(self, nb): 
if self.rest is not None: - prompts, answers = rest + quizzes = rest else: - prompts, answers = [], [] + quizzes = [] self.rest = None - n = sum([p.size(0) for p in prompts]) - - with tqdm.tqdm( - total=nb, - dynamic_ncols=True, - desc="world generation", - ) as pbar: + n = sum([q.size(0) for q in quizzes]) + + if progress_bar: + with tqdm.tqdm( + total=nb, dynamic_ncols=True, desc="world generation", delay=10 + ) as pbar: + while n < nb: + q = self.queue.get(block=True) + quizzes.append(q) + n += q.size(0) + pbar.update(q.size(0)) + else: while n < nb: - p, s = self.queue.get(block=True) - prompts.append(p) - answers.append(s) - n += p.size(0) - pbar.update(p.size(0)) + q = self.queue.get(block=True) + quizzes.append(q) + n += q.size(0) - prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0) - assert n == prompts.size(0) + quizzes = torch.cat(quizzes, dim=0) + assert n == quizzes.size(0) k = n - nb if k > 0: - rest = (prompts[-k:], answers[-k:]) - prompts, answers = prompts[:-k], answers[:-k] + rest = quizzes[-k:] + quizzes = quizzes[:-k] + + return quizzes + + ###################################################################### + + def trivial_prompts_and_answers(self, prompts, answers): + pass + + # The one to implement, returns two tensors nb x D and nb x D' + def generate_w_quizzes_(self, nb): + pass + + # save a file to vizualize quizzes, you can save a txt or png file + def save_quiz_illustrations( + self, + result_dir, + filename_prefix, + prompts, + answers, + predicted_prompts=None, + predicted_answers=None, + ): + pass + + def save_some_examples(self, result_dir): + pass - return prompts, answers + ###################################################################### diff --git a/quiz_machine.py b/quiz_machine.py deleted file mode 100755 index f0fb408..0000000 --- a/quiz_machine.py +++ /dev/null @@ -1,541 +0,0 @@ -#!/usr/bin/env python - -# Any copyright is dedicated to the Public Domain. 
-# https://creativecommons.org/publicdomain/zero/1.0/ - -# Written by Francois Fleuret - -import math, os, tqdm, warnings - -import torch, torchvision - -from torch import nn -from torch.nn import functional as F - -import mygpt -from mygpt import BracketedSequence - -###################################################################### - -# ar_mask is a tensor with 0s and 1s, of same shape as input, with -# 1s where tokens should be generated. The others are kept -# unchanged. - - -def one_batch_masked_inplace_autoregression( - model, - input, - ar_mask, - seq_logproba, - temperature, - deterministic_synthesis, -): - to_generate = (ar_mask.sum(0) > 0).nonzero() - - if to_generate.min() > 0: - model( - BracketedSequence(input, 0, to_generate.min()) - ) # Needed to initialize the model's cache - for s in range(to_generate.min(), to_generate.max() + 1): - output = model(BracketedSequence(input, s, 1)).x - - logits = output[:, s] - - logits = (logits / temperature).log_softmax(dim=-1) - - if deterministic_synthesis: - t_next = logits.argmax(-1) - else: - dist = torch.distributions.categorical.Categorical(logits=logits) - t_next = dist.sample() - - all_n = torch.arange(t_next.size(0)) - - seq_logproba += logits[all_n, t_next] - - input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s] - - -def masked_inplace_autoregression( - model, - batch_size, - input, - ar_mask, - seq_logproba, - temperature, - deterministic_synthesis, - forbidden_tokens=None, - logit_biases=None, - progress_bar_desc=None, - device=torch.device("cpu"), -): - assert input.size() == ar_mask.size() - - batches = zip( - input.split(batch_size), - ar_mask.split(batch_size), - seq_logproba.split(batch_size), - ) - - if progress_bar_desc is not None: - batches = tqdm.tqdm( - batches, - dynamic_ncols=True, - desc=progress_bar_desc, - total=(input.size(0) + batch_size - 1) // batch_size, - ) - - with torch.autograd.no_grad(): - t = model.training - model.eval() - - for input, ar_mask, 
seq_logproba in batches: - one_batch_masked_inplace_autoregression( - model=model, - input=input, - ar_mask=ar_mask, - seq_logproba=seq_logproba, - temperature=temperature, - deterministic_synthesis=deterministic_synthesis, - ) - - model.train(t) - - -###################################################################### - - -class QuizMachine: - def indices_forward_and_backward(self, quizzes): - i_forward = quizzes[:, 0] == self.token_forward - j_forward = quizzes[:, 1 + self.prompt_len] == self.token_forward - i_backward = quizzes[:, 0] == self.token_backward - j_backward = quizzes[:, 1 + self.answer_len] == self.token_backward - assert torch.logical_or( - torch.logical_and(i_forward, j_forward), - torch.logical_and(i_backward, j_backward), - ).all() - return i_forward, i_backward - - def non_trivial(self, quizzes): - quizzes = quizzes.clone() - n_forward = quizzes[quizzes[:, 0] == self.token_forward] - n_backward = quizzes[:, 0] == self.token_backward - backward = quizzes[n_backward] - quizzes[n_backward] = self.reverse_time(quizzes[n_backward]) - return torch.logical_not( - self.problem.trivial_prompts_and_answers( - quizzes[:, 1 : 1 + self.prompt_len], - quizzes[:, 2 + self.prompt_len :], - ) - ) - - def reverse_time(self, quizzes): - i_forward, i_backward = self.indices_forward_and_backward(quizzes) - - forward_to_backward = torch.cat( - [ - quizzes[:, 0:1], - quizzes[:, 2 + self.prompt_len : 2 + self.prompt_len + self.answer_len], - quizzes[:, 1 + self.prompt_len : 1 + self.prompt_len + 1], - quizzes[:, 1 : 1 + self.prompt_len], - ], - dim=1, - ) - - forward_to_backward[:, 0] = self.token_backward - forward_to_backward[:, 1 + self.answer_len] = self.token_backward - - backward_to_forward = torch.cat( - [ - quizzes[:, 0:1], - quizzes[:, 2 + self.answer_len :], - quizzes[:, 1 + self.answer_len : 2 + self.answer_len], - quizzes[:, 1 : 1 + self.answer_len], - ], - dim=1, - ) - - backward_to_forward[:, 0] = self.token_forward - backward_to_forward[:, 1 + 
self.prompt_len] = self.token_forward - - m = i_forward.long()[:, None] - - return m * forward_to_backward + (1 - m) * backward_to_forward - - def reverse_random_half_in_place(self, quizzes): - i = torch.rand(quizzes.size(0)) < 0.5 - if i.any(): - quizzes[i] = self.reverse_time(quizzes[i]) - - def make_ar_mask(self, quizzes, first=False): - i_forward, i_backward = self.indices_forward_and_backward(quizzes) - - t = torch.arange(quizzes.size(1), device=quizzes.device) - - if first: - m_forward = (t >= 1).long() * (t < 1 + self.prompt_len).long() - m_backward = (t >= 1).long() * (t < 1 + self.answer_len).long() - else: - m_forward = (t >= 2 + self.prompt_len).long() - m_backward = (t >= 2 + self.answer_len).long() - - m = i_forward.long()[:, None] - - return m * m_forward + (1 - m) * m_backward - - def generate_token_sequences(self, nb): - prompts, answers = self.problem.generate_prompts_and_answers(nb) - - if self.prompt_len is None: - self.prompt_len = prompts.size(1) - - if self.answer_len is None: - self.answer_len = answers.size(1) - - assert prompts.size(1) == self.prompt_len and answers.size(1) == self.answer_len - - result = [] - - for prompt, answer in zip(prompts, answers): - a = [ - torch.tensor([self.token_forward]), - prompt, - torch.tensor([self.token_forward]), - answer, - ] - - result.append(torch.cat(a, dim=0)[None, :]) - - return torch.cat(result, dim=0) - - def __init__( - self, - problem, - nb_train_samples, - nb_test_samples, - back_accuracy, - batch_size, - result_dir, - logger, - device=torch.device("cpu"), - ): - super().__init__() - - v = problem.nb_token_values() - self.token_forward = v - self.token_backward = v + 1 - self.nb_token_values = v + 2 - - self.problem = problem - self.back_accuracy = back_accuracy - self.batch_size = batch_size - self.device = device - self.logger = logger - self.prompt_len = None - self.answer_len = None - - self.train_w_quizzes = self.generate_token_sequences(nb_train_samples) - 
self.reverse_random_half_in_place(self.train_w_quizzes) - self.train_w_quizzes = self.train_w_quizzes.to(device) - - self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device) - self.reverse_random_half_in_place(self.test_w_quizzes) - self.test_w_quizzes = self.test_w_quizzes.to(device) - - self.train_c_quizzes = [] - self.test_c_quizzes = [] - - if result_dir is not None: - self.save_quizzes( - result_dir, - "culture_w_quizzes", - self.train_w_quizzes[:72], - ) - - def save_quizzes( - self, - result_dir, - filename_prefix, - quizzes, - mistakes=None, - ): - quizzes = quizzes.clone() - n_forward = quizzes[quizzes[:, 0] == self.token_forward] - n_backward = quizzes[:, 0] == self.token_backward - backward = quizzes[n_backward] - assert n_forward.size(0) + backward.size(0) == quizzes.size(0) - quizzes[n_backward] = self.reverse_time(quizzes[n_backward]) - - predicted_prompts = n_backward.long() - predicted_answers = 1 - predicted_prompts - if mistakes is not None: - # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct - predicted_prompts *= mistakes - predicted_answers *= mistakes - else: - # 0/2 ~ not-to-predict / to predict - predicted_prompts *= 2 - predicted_answers *= 2 - - self.problem.save_quizzes( - result_dir, - filename_prefix, - quizzes[:, 1 : 1 + self.prompt_len], - quizzes[:, 2 + self.prompt_len :], - predicted_prompts, - predicted_answers, - ) - - def batches(self, split="train", desc=None): - assert split in {"train", "test"} - if split == "train": - w_quizzes = self.train_w_quizzes - c_quizzes = self.train_c_quizzes - else: - w_quizzes = self.test_w_quizzes - c_quizzes = self.test_c_quizzes - - if len(c_quizzes) > 0: - c_quizzes = torch.cat(c_quizzes, dim=0) - if c_quizzes.size(0) > w_quizzes.size(0) // 2: - i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2] - c_quizzes = c_quizzes[i] - - i = torch.randperm(w_quizzes.size(0))[ - : w_quizzes.size(0) - c_quizzes.size(0) - ] - w_quizzes = w_quizzes[i] - - 
self.nb_batch_w_quizzes = w_quizzes.size(0) - self.nb_batch_c_quizzes = c_quizzes.size(0) - - input = torch.cat([w_quizzes, c_quizzes], dim=0) - else: - input = w_quizzes - self.nb_batch_w_quizzes = w_quizzes.size(0) - self.nb_batch_c_quizzes = 0 - - # Shuffle - input = input[torch.randperm(input.size(0))] - - if desc is None: - desc = f"epoch-{split}" - for batch in tqdm.tqdm( - input.split(self.batch_size), dynamic_ncols=True, desc=desc - ): - yield batch - - def vocabulary_size(self): - return self.nb_token_values - - def produce_results( - self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000 - ): - def compute_accuracy(input, log_prefix=None): - ar_mask = self.make_ar_mask(input) - result = input.clone() * (1 - ar_mask) - seq_logproba = torch.empty(input.size(0), device=self.device) - - masked_inplace_autoregression( - model=model, - batch_size=self.batch_size, - input=result, - ar_mask=ar_mask, - seq_logproba=seq_logproba, - temperature=1.0, - deterministic_synthesis=deterministic_synthesis, - progress_bar_desc=None, - device=self.device, - ) - - correct = torch.empty(input.size(0), dtype=torch.int64, device=input.device) - - n_forward = input[:, 0] == self.token_forward - n_backward = input[:, 0] == self.token_backward - - correct[n_forward] = ( - (input[n_forward] == result[n_forward]).long().min(dim=1).values - ) - - if self.back_accuracy and n_backward.any(): - # accuracy of B->A*->B*=B instead of B->A*=A - back_input = self.reverse_time(result[n_backward]) - back_input[:, 2 + self.prompt_len :] = input[ - n_backward, 1 : 1 + self.answer_len - ] - _, correct[n_backward] = compute_accuracy(back_input) - - if log_prefix is not None: - forward_nb_correct = correct[n_forward].sum() - forward_nb_total = correct[n_forward].size(0) - backward_nb_correct = correct[n_backward].sum() - backward_nb_total = correct[n_backward].size(0) - - self.logger( - f"{log_prefix}_forward_accuracy {n_epoch} model {model.id} nb_correct {forward_nb_correct} / 
{forward_nb_total} ({forward_nb_correct*100/forward_nb_total} %)" - ) - - self.logger( - f"{log_prefix}_backward_accuracy {n_epoch} model {model.id} nb_correct {backward_nb_correct} / {backward_nb_total} ({backward_nb_correct*100/backward_nb_total} %)" - ) - - return result, correct - - compute_accuracy(self.train_w_quizzes[:nmax], log_prefix="train") - - test_result, test_correct = compute_accuracy( - self.test_w_quizzes[:nmax], log_prefix="test" - ) - - main_test_accuracy = test_correct.sum() / test_correct.size(0) - self.logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}") - - ############################## - - self.save_quizzes( - result_dir, - f"culture_prediction_{n_epoch:04d}_{model.id:02d}", - quizzes=test_result[:72], - mistakes=test_correct[:72] * 2 - 1, - ) - - return main_test_accuracy - - def renew_w_quizzes(self, nb, for_train=True): - input = self.train_w_quizzes if for_train else self.test_w_quizzes - nb = min(nb, input.size(0)) - input[:-nb] = input[nb:].clone() - fresh_w_quizzes = self.generate_token_sequences(nb) - self.reverse_random_half_in_place(fresh_w_quizzes) - input[-nb:] = fresh_w_quizzes.to(self.device) - - def store_c_quizzes(self, new_c_quizzes, for_train=True): - if for_train: - self.train_c_quizzes.append(new_c_quizzes) - else: - self.test_c_quizzes.append(new_c_quizzes) - - def compute_correctness( - self, - c_quizzes, - models_for_validation, - bidirectional_validation=False, - deterministic_validation=True, - ): - if bidirectional_validation: - backward_c_quizzes = self.forward_to_backward(c_quizzes) - - seq_logproba = torch.zeros( - c_quizzes.size(0), - max([m.id for m in models_for_validation]) + 1, - device=self.device, - ) - - nb_correct = 0 - - seq_logproba[...] 
= 0.0 - - for model in models_for_validation: - result = c_quizzes.clone() - - ar_mask = self.make_ar_mask(result) - - masked_inplace_autoregression( - model=model, - batch_size=self.batch_size, - input=result, - ar_mask=ar_mask, - seq_logproba=seq_logproba[:, model.id], - temperature=1.0, - deterministic_synthesis=deterministic_validation, - # progress_bar_desc="solving c_quizzes", - device=self.device, - ) - - correct = (c_quizzes == result).long().min(dim=-1).values - - if bidirectional_validation: - backward_result = backward_c_quizzes.clone() - - ar_mask = self.make_ar_mask(backward_result) - - masked_inplace_autoregression( - model=model, - batch_size=self.batch_size, - input=backward_result, - ar_mask=ar_mask, - seq_logproba=seq_logproba[:, model.id], - temperature=1.0, - deterministic_synthesis=deterministic_validation, - # progress_bar_desc="solving backward c_quizzes", - device=self.device, - ) - - backward_correct = ( - (backward_c_quizzes == backward_result).long().min(dim=-1).values - ) - - correct *= backward_correct - - # endif - - nb_correct += correct - - return nb_correct, seq_logproba - - ############################################################### - - def generate_quizzes(self, nb, model_for_generation, temperature=1.0): - c_quizzes = torch.empty( - nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64 - ) - - seq_logproba = torch.zeros(nb, device=self.device) - - # First, we generate the answer at high temperature - - c_quizzes[:, 0] = self.token_backward - c_quizzes[:, 1 + self.answer_len] = self.token_backward - - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=self.make_ar_mask(c_quizzes, first=True), - seq_logproba=seq_logproba, - temperature=temperature, - deterministic_synthesis=False, - device=self.device, - ) - - # Then, we generate the prompt at low temperature - - masked_inplace_autoregression( - model=model_for_generation, - 
batch_size=self.batch_size, - input=c_quizzes, - ar_mask=self.make_ar_mask(c_quizzes), - seq_logproba=seq_logproba, - temperature=1 / temperature, - deterministic_synthesis=False, - device=self.device, - ) - - # Then we return the quizz, and re-generate the response, now - # at low temperature - - c_quizzes = self.reverse_time(c_quizzes) - - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=self.make_ar_mask(c_quizzes), - seq_logproba=seq_logproba, - temperature=1 / temperature, - deterministic_synthesis=False, - device=self.device, - ) - - return c_quizzes diff --git a/sky.py b/sky.py deleted file mode 100755 index ed440d3..0000000 --- a/sky.py +++ /dev/null @@ -1,360 +0,0 @@ -#!/usr/bin/env python - -# Any copyright is dedicated to the Public Domain. -# https://creativecommons.org/publicdomain/zero/1.0/ - -# Written by Francois Fleuret - -import math, sys, tqdm, os, warnings - -import torch, torchvision - -from torch import nn -from torch.nn import functional as F - -###################################################################### - -import problem - - -class Sky(problem.Problem): - colors = torch.tensor( - [ - [255, 255, 255], - [255, 0, 0], - [0, 192, 0], - [0, 0, 255], - [255, 192, 0], - [0, 255, 255], - [255, 0, 255], - [192, 255, 192], - [255, 192, 192], - [192, 192, 255], - [192, 192, 192], - ] - ) - - token_background = 0 - first_bird_token = 1 - nb_bird_tokens = colors.size(0) - 1 - - token2char = ( - "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><" - ) - - def __init__( - self, - height=6, - width=8, - nb_birds=3, - speed=2, - nb_iterations=2, - avoid_collision=True, - ): - self.height = height - self.width = width - self.nb_birds = nb_birds - self.speed = speed - self.nb_iterations = nb_iterations - self.avoid_collision = avoid_collision - - def generate_frame_sequences(self, nb): - frame_sequences = [] - - for _ in tqdm.tqdm(range(nb), 
dynamic_ncols=True, desc="world generation"): - i, j, vi, vj = ( - torch.empty(self.nb_birds, dtype=torch.int64), - torch.empty(self.nb_birds, dtype=torch.int64), - torch.empty(self.nb_birds, dtype=torch.int64), - torch.empty(self.nb_birds, dtype=torch.int64), - ) - - def collision_okay(): - if not self.avoid_collision: - return True - - count = torch.zeros(self.height, self.width, dtype=torch.int64) - - for n in range(self.nb_birds): - count[i[n], j[n]] += 1 - count[i[n] - vi[n], j[n]] += 1 - count[i[n], j[n] - vj[n]] += 1 - - return count.max() <= 1 - - col = ( - torch.randperm(self.colors.size(0) - 1)[: self.nb_birds].sort().values - + 1 - ) - - while True: - while True: - for n in range(self.nb_birds): - while True: - i[n] = torch.randint(self.height, (1,)) - j[n] = torch.randint(self.width, (1,)) - vm = torch.randint(4, (1,)) - vi[n], vj[n] = (vm % 2) * 2 - 1, (vm // 2) * 2 - 1 - if ( - i[n] - vi[n] >= 0 - and i[n] - vi[n] < self.height - and j[n] - vj[n] >= 0 - and j[n] - vj[n] < self.width - ): - break - - if collision_okay(): - break - - result = torch.zeros( - self.nb_iterations * self.speed, - self.height, - self.width, - dtype=torch.int64, - ) - - fine = torch.empty(self.nb_iterations * self.speed) - - t_to_keep = ( - torch.arange(self.nb_iterations, device=result.device) * self.speed - ) - - for l in range(self.nb_iterations * self.speed): - fine[l] = collision_okay() - for n in range(self.nb_birds): - c = col[n] - result[l, i[n], j[n]] = c - result[l, i[n] - vi[n], j[n]] = c - result[l, i[n], j[n] - vj[n]] = c - - if (i[n] == 0 and vi[n] == -1) or ( - i[n] == self.height - 1 and vi[n] == 1 - ): - vi[n] = -vi[n] - - if (j[n] == 0 and vj[n] == -1) or ( - j[n] == self.width - 1 and vj[n] == 1 - ): - vj[n] = -vj[n] - - i[n] += vi[n] - j[n] += vj[n] - - result = result[t_to_keep] - fine = fine[t_to_keep] - - if fine[-1]: - break - - frame_sequences.append(result) - - return frame_sequences - - 
###################################################################### - - def frame2img(self, x, scale=15): - x = x.reshape(x.size(0), self.height, -1) - m = torch.logical_and( - x >= 0, x < self.first_bird_token + self.nb_bird_tokens - ).long() - x = self.colors[x * m].permute(0, 3, 1, 2) - s = x.shape - x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale) - x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale) - - x[:, :, :, torch.arange(0, x.size(3), scale)] = 0 - x[:, :, torch.arange(0, x.size(2), scale), :] = 0 - x = x[:, :, 1:, 1:] - - for n in range(m.size(0)): - for i in range(m.size(1)): - for j in range(m.size(2)): - if m[n, i, j] == 0: - for k in range(2, scale - 2): - for l in [0, 1]: - x[n, :, i * scale + k, j * scale + k - l] = 0 - x[ - n, :, i * scale + scale - 1 - k, j * scale + k - l - ] = 0 - - return x - - def seq2str(self, seq): - result = [] - for s in seq: - result.append("".join([self.token2char[v] for v in s])) - return result - - def save_image( - self, - result_dir, - filename, - prompts, - answers, - predicted_prompts=None, - predicted_answers=None, - ): - if predicted_prompts is None: - predicted_prompts = 255 - - if predicted_answers is None: - predicted_answers = 255 - - def add_frame(x, c, margin, bottom=False): - if bottom: - h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0 - else: - h, w, di, dj = ( - x.size(2) + 2 * margin, - x.size(3) + 2 * margin, - margin, - margin, - ) - - y = x.new_full((x.size(0), x.size(1), h, w), 0) - - if type(c) is int: - y[...] = c - else: - c = c.long()[:, None] - c = ( - (c == 1).long() * torch.tensor([0, 255, 0], device=c.device) - + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device) - + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device) - ) - y[...] 
= c[:, :, None, None] - - y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x - - return y - - margin = 4 - - img_prompts = add_frame(self.frame2img(prompts.to("cpu")), c=0, margin=1) - h = img_prompts.size(2) - img_answers = add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1) - - img_prompts = add_frame(img_prompts, c=255, margin=margin, bottom=True) - img_answers = add_frame(img_answers, c=255, margin=margin, bottom=True) - - img_prompts = add_frame( - img_prompts, c=predicted_prompts, margin=margin, bottom=True - ) - img_answers = add_frame( - img_answers, c=predicted_answers, margin=margin, bottom=True - ) - - marker_size = 16 - - separator = img_prompts.new_full( - ( - img_prompts.size(0), - img_prompts.size(1), - img_prompts.size(2), - marker_size, - ), - 255, - ) - - separator[:, :, 0] = 0 - separator[:, :, h - 1] = 0 - - for k in range(1, 2 * marker_size - 8): - i = k - (marker_size - 4) - j = marker_size - 5 - abs(i) - separator[:, :, h // 2 - 1 + i, 2 + j] = 0 - separator[:, :, h // 2 - 1 + i + 1, 2 + j] = 0 - - img = torch.cat([img_prompts, separator, img_answers], dim=3) - - image_name = os.path.join(result_dir, filename) - torchvision.utils.save_image( - img.float() / 255.0, image_name, nrow=6, padding=margin * 4, pad_value=1.0 - ) - - ###################################################################### - - def nb_token_values(self): - return len(self.colors) - - def generate_prompts_and_answers(self, nb): - frame_sequences = self.generate_frame_sequences(nb) - frame_sequences = torch.cat([x[None] for x in frame_sequences], dim=0) - - prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1) - - answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1) - - # warnings.warn("dirty test with longer answer", RuntimeWarning) - # answers = torch.cat( - # [ - # frame_sequences[:, frame_sequences.size(1) // 2 :], - # frame_sequences[:, frame_sequences.size(1) // 2 :], - # ], - # dim=3, - # ).flatten(1) - - return 
prompts, answers - - def save_quizzes( - self, - result_dir, - filename_prefix, - prompts, - answers, - predicted_prompts=None, - predicted_answers=None, - ): - self.save_image( - result_dir, - filename_prefix + ".png", - prompts, - answers, - predicted_prompts, - predicted_answers, - ) - - -###################################################################### - -if __name__ == "__main__": - import time - - sky = Sky(height=6, width=8, speed=1, nb_iterations=4) - - prompts, answers = sky.generate_prompts_and_answers(4) - - predicted_prompts = torch.randint(3, (prompts.size(0),)) - 1 - predicted_answers = torch.randint(3, (prompts.size(0),)) - 1 - - sky.save_quizzes( - "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers - ) - - # start_time = time.perf_counter() - # token_sequences = sky.generate_token_sequences(nb=64) - # delay = time.perf_counter() - start_time - # print(f"{token_sequences.size(0)/delay:02f} seq/s") - - # print(sky.seq2str(seq[:4])) - - # for t in range(len(it[0])): - # img = torch.cat([sky.frame2img(f[t]) for f in it], dim=0) - # torchvision.utils.save_image( - # img.float() / 255.0, - # f"/tmp/frame_{t:03d}.png", - # nrow=8, - # padding=6, - # pad_value=0, - # ) - - # m = (torch.rand(seq.size()) < 0.05).long() - # seq = (1 - m) * seq + m * 23 - - # print(seq.size()) - # img = sky.seq2img(token_sequences) - # print(img.size()) - - # torchvision.utils.save_image( - # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0 - # ) diff --git a/tasks.py b/tasks.py deleted file mode 100755 index 80ffdbb..0000000 --- a/tasks.py +++ /dev/null @@ -1,374 +0,0 @@ -#!/usr/bin/env python - -# Any copyright is dedicated to the Public Domain. 
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret

import math, os, tqdm, warnings

import torch, torchvision

from torch import nn
from torch.nn import functional as F

from mygpt import BracketedSequence

######################################################################


def masked_inplace_autoregression(
    model,
    batch_size,
    input,
    ar_mask,
    summed_logits,
    temperature,
    deterministic_synthesis,
    forbidden_tokens=None,
    logit_biases=None,
    progress_bar_desc="autoregression",
    device=torch.device("cpu"),
):
    """Run model.masked_inplace_autoregression on `input`, batch by batch.

    `input` and `ar_mask` must have the same size. They are split in chunks
    of `batch_size` rows which are handed to the model one at a time; the
    chunks are views, so the model is expected to update `input` in place.
    The model is put in eval mode and gradients are disabled for the
    duration of the call; its previous training mode is restored afterwards,
    even if the generation raises.

    A progress bar is shown unless `progress_bar_desc` is None. `device` is
    accepted for interface compatibility and is not used here.
    """
    assert input.size() == ar_mask.size()

    batches = zip(input.split(batch_size), ar_mask.split(batch_size))

    if progress_bar_desc is not None:
        batches = tqdm.tqdm(
            batches,
            dynamic_ncols=True,
            desc=progress_bar_desc,
            total=(input.size(0) + batch_size - 1) // batch_size,
        )

    with torch.autograd.no_grad():
        was_training = model.training
        model.eval()

        try:
            # Distinct loop names so that the parameters are not shadowed
            for batch_input, batch_ar_mask in batches:
                model.masked_inplace_autoregression(
                    input=batch_input,
                    ar_mask=batch_ar_mask,
                    summed_logits=summed_logits,
                    temperature=temperature,
                    deterministic_synthesis=deterministic_synthesis,
                    forbidden_tokens=forbidden_tokens,
                    forced_biases=logit_biases,
                )
        finally:
            # Do not leave the model stuck in eval mode after a failure
            model.train(was_training)


######################################################################


class Task:
    """Interface of a task; every method is a placeholder to override."""

    def batches(self, split="train", nb_to_use=-1, desc=None):
        pass

    def vocabulary_size(self):
        pass

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        pass


######################################################################

import world


class World(Task):
    """Task built on the sequences produced by the `world` module.

    Besides the generated train / test samples, the task keeps two lists of
    additional "quizzes" that are mixed into the batches.
    """

    def save_image(self, input, result_dir, filename, logger):
        """Render `input` with world.seq2img and save it in result_dir."""
        img = world.seq2img(input.to("cpu"), self.height, self.width)
        image_name = os.path.join(result_dir, filename)
        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
        # __init__ accepts logger=None, so the logger is optional here
        if logger is not None:
            logger(f"wrote {image_name}")

    def make_ar_mask(self, input):
        """Return a 0/1 mask shaped like `input`.

        The mask is 1 at the positions strictly after index
        input.size(1) // 2 and 0 elsewhere.
        """
        b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2
        return b.long()[None, :].expand_as(input)

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        result_dir=None,
        logger=None,
        device=torch.device("cpu"),
    ):
        """Generate the train / test samples on a 6 x 8 grid.

        If `result_dir` is not None, the first 72 train samples are saved
        there as world_train.png.
        """
        super().__init__()

        self.batch_size = batch_size
        self.device = device
        self.height = 6
        self.width = 8

        self.train_input = world.generate_seq(
            nb_train_samples, height=self.height, width=self.width
        ).to(device)

        self.test_input = world.generate_seq(
            nb_test_samples, height=self.height, width=self.width
        ).to(device)

        # Largest token value seen in the data, plus one
        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

        self.train_quizzes = []
        self.test_quizzes = []

        if result_dir is not None:
            self.save_image(
                self.train_input[:72], result_dir, "world_train.png", logger
            )

    def batches(self, split="train", desc=None):
        """Yield shuffled batches of the requested split.

        When quizzes have been stored for the split, they replace part of
        the world samples, up to half of them, so that the total number of
        samples stays input.size(0). The numbers of samples of each kind are
        recorded in self.nb_batch_samples_world and
        self.nb_batch_samples_quizzes.
        """
        assert split in {"train", "test"}
        if split == "train":
            input = self.train_input
            quizzes = self.train_quizzes
        else:
            input = self.test_input
            quizzes = self.test_quizzes

        if len(quizzes) > 0:
            quizzes = torch.cat(quizzes, dim=0)
            if quizzes.size(0) > input.size(0) // 2:
                # The permutation must range over the quizzes, not over the
                # world samples, otherwise the indices may be out of bounds
                # or may never reach the most recent quizzes
                i = torch.randperm(quizzes.size(0))[: input.size(0) // 2]
                quizzes = quizzes[i]

            i = torch.randperm(input.size(0))[: input.size(0) - quizzes.size(0)]
            input = input[i]

            self.nb_batch_samples_world = input.size(0)
            self.nb_batch_samples_quizzes = quizzes.size(0)

            input = torch.cat([input, quizzes], dim=0)
        else:
            self.nb_batch_samples_world = input.size(0)
            self.nb_batch_samples_quizzes = 0

        # Shuffle
        input = input[torch.randperm(input.size(0))]

        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        """Return self.nb_codes, as computed in __init__."""
        return self.nb_codes

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
    ):
        """Log train / test accuracies and save a picture of predictions.

        A sample counts as correct when the regenerated second half matches
        the original exactly. At most `nmax` samples of each split are used.
        Returns the test accuracy.
        """

        def compute_accuracy(input, logger=None):
            input = input[:nmax]
            ar_mask = self.make_ar_mask(input)
            # Erase the part to predict before regenerating it
            result = input.clone() * (1 - ar_mask)

            masked_inplace_autoregression(
                model=model,
                batch_size=self.batch_size,
                input=result,
                ar_mask=ar_mask,
                summed_logits=None,
                temperature=1.0,
                deterministic_synthesis=deterministic_synthesis,
                progress_bar_desc=None,
                device=self.device,
            )

            nb_total, nb_correct = (
                input.size(0),
                (input == result).long().min(dim=1).values.sum(),
            )

            return nb_total, nb_correct

        train_nb_total, train_nb_correct = compute_accuracy(self.train_input)

        logger(
            f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
        )

        test_nb_total, test_nb_correct = compute_accuracy(self.test_input, logger)

        logger(
            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
        )

        main_test_accuracy = test_nb_correct / test_nb_total
        logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}")

        ##############################

        input = self.test_input[:96]
        ar_mask = self.make_ar_mask(input)
        result = input.clone() * (1 - ar_mask)

        masked_inplace_autoregression(
            model=model,
            batch_size=self.batch_size,
            input=result,
            ar_mask=ar_mask,
            summed_logits=None,
            temperature=1.0,
            deterministic_synthesis=deterministic_synthesis,
            progress_bar_desc=None,
            device=self.device,
        )

        self.save_image(
            result[:72],
            result_dir,
            f"world_prediction_{n_epoch:04d}_{model.id:02d}.png",
            logger,
        )

        return main_test_accuracy

    def renew_samples(self, nb, for_train=True):
        """Drop the `nb` oldest samples of a split and append `nb` new ones."""
        input = self.train_input if for_train else self.test_input
        nb = min(nb, input.size(0))
        if nb <= 0:
            # input[:-0] is empty, so the shift below would fail
            return
        input[:-nb] = input[nb:].clone()
        input[-nb:] = world.generate_seq(nb, height=self.height, width=self.width).to(
            self.device
        )

    def store_new_quizzes(self, new_quizzes, for_train=True):
        """Append `new_quizzes` to the train or test quiz list."""
        if for_train:
            self.train_quizzes.append(new_quizzes)
        else:
            self.test_quizzes.append(new_quizzes)

    def create_new_quizzes(
        self,
        n_epoch,
        result_dir,
        logger,
        nb,
        model,
        other_models,
        desired_average_logits=None,
    ):
        """Generate `nb` quizzes with `model` and have `other_models` solve them.

        Returns the quizzes, the number of models of `other_models` that
        solve each quiz in both directions, and the mean of the summed
        logits of the generation.
        """
        ###############################################################
        # Generate quizzes with model

        quizzes = torch.empty(
            nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
        )

        # Every position is generated, so the uninitialized content of
        # `quizzes` is never used as a prompt
        ar_mask = torch.full(quizzes.size(), 1, device=self.device)
        summed_logits = torch.empty(nb, device=self.device)

        temperature = 1
        d_temperature = 1

        while True:
            summed_logits[...] = 0

            masked_inplace_autoregression(
                model=model,
                batch_size=self.batch_size,
                input=quizzes,
                ar_mask=ar_mask,
                summed_logits=summed_logits,
                temperature=temperature,
                deterministic_synthesis=False,
                progress_bar_desc="creating quizzes",
                device=self.device,
            )

            average_logits = summed_logits.mean()

            logger(f"{average_logits=} {desired_average_logits=}")

            if desired_average_logits is None:
                break

            # Oh man that's ugly
            # The temperature step is halved and reversed whenever the
            # search has to change direction
            if average_logits < desired_average_logits * 1.1:
                if d_temperature > 0:
                    d_temperature *= -0.5
                temperature += d_temperature
            elif average_logits > desired_average_logits:
                if d_temperature < 0:
                    d_temperature *= -0.5
                temperature += d_temperature
            else:
                break

            logger(f"changing temperature to {temperature}")

        ###############################################################
        # Create the reverse quizzes

        l = self.height * self.width
        direction = quizzes[:, l : l + 1]
        # Swap the forward and backward tokens; any other value becomes 0
        direction = world.token_forward * (
            direction == world.token_backward
        ) + world.token_backward * (direction == world.token_forward)
        reverse_quizzes = torch.cat(
            [quizzes[:, l + 1 :], direction, quizzes[:, :l]], dim=1
        )

        ar_mask = self.make_ar_mask(quizzes)

        ###############################################################
        # Check how many of the other models can solve them in both
        # directions

        nb_correct = []

        for m in other_models:
            result = quizzes.clone()

            masked_inplace_autoregression(
                model=m,
                batch_size=self.batch_size,
                input=result,
                ar_mask=ar_mask,
                summed_logits=None,
                temperature=1.0,
                deterministic_synthesis=True,
                progress_bar_desc="solving quizzes",
                device=self.device,
            )

            correct = (quizzes == result).long().min(dim=-1).values

            reverse_result = reverse_quizzes.clone()

            masked_inplace_autoregression(
                model=m,
                batch_size=self.batch_size,
                input=reverse_result,
                ar_mask=ar_mask,
                summed_logits=None,
                temperature=1.0,
                deterministic_synthesis=True,
                progress_bar_desc="solving reversed quizzes",
                device=self.device,
            )

            reverse_correct = (
                (reverse_quizzes == reverse_result).long().min(dim=-1).values
            )

            nb_correct.append((correct * reverse_correct)[None, :])

        nb_correct = torch.cat(nb_correct, dim=0)

        return quizzes, nb_correct.sum(dim=0), summed_logits.mean()

# diff --git a/wireworld.py b/wireworld.py
# deleted file mode 100755
# index 8257cad..0000000
# --- a/wireworld.py
# +++ /dev/null
# @@ -1,357 +0,0 @@

#!/usr/bin/env python

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret

import math, sys, tqdm, os

import torch, torchvision

from torch import nn
from torch.nn import functional as F

######################################################################

import problem


class Wireworld(problem.Problem):
    """Generator of frame sequences of a wireworld-like cellular automaton.

    Cells are empty, electron head, electron tail or conductor. Two extra
    tokens mark the direction (forward / backward in time) between the
    frames of a token sequence.
    """

    # One color per cell token, indexed by token value
    colors = torch.tensor(
        [
            [128, 128, 128],
            [128, 128, 255],
            [255, 0, 0],
            [255, 255, 0],
        ]
    )

    token_empty = 0
    token_head = 1
    token_tail = 2
    token_conductor = 3
    token_forward = 4
    token_backward = 5

    token2char = (
        "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
    )

    def __init__(
        self, height=6, width=8, nb_objects=2, nb_walls=2, speed=1, nb_iterations=4
    ):
        """Store the grid size and the simulation parameters.

        `nb_iterations` frames are kept per sequence, `speed` automaton
        steps apart.
        """
        self.height = height
        self.width = width
        self.nb_objects = nb_objects
        self.nb_walls = nb_walls
        self.speed = speed
        self.nb_iterations = nb_iterations

    def direction_tokens(self):
        """Return the (forward, backward) direction tokens."""
        return self.token_forward, self.token_backward

    def generate_frame_sequences(self, nb):
        """Return `nb` frame sequences, generated in batches of 100.

        generate_frame_sequences_hard returns exactly as many sequences as
        requested, so ceil(nb / 100) batches are enough. At least one batch
        is generated so that nb == 0 still yields an empty tensor of the
        right shape.
        """
        batch_size = 100
        nb_batches = max(1, (nb + batch_size - 1) // batch_size)
        result = [
            self.generate_frame_sequences_hard(batch_size)
            for _ in tqdm.tqdm(
                range(nb_batches), dynamic_ncols=True, desc="world generation"
            )
        ]
        return torch.cat(result, dim=0)[:nb]

    def generate_frame_sequences_hard(self, nb):
        """Generate exactly `nb` valid frame sequences.

        4 * nb candidate worlds are drawn and simulated. A candidate is kept
        if it starts with at least one head, if its number of heads never
        decreases, and if its last kept frame still contains a head. If too
        few candidates survive, the missing ones are generated recursively.
        """
        nb_frames = (self.nb_iterations - 1) * self.speed + 1

        result = torch.full(
            (nb * 4, nb_frames, self.height, self.width),
            self.token_empty,
        )

        for n in range(result.size(0)):
            # Draw straight conductor segments until there are more than
            # self.width conductor cells (then stop with probability 1/2)
            while True:
                i = torch.randint(self.height, (1,))
                j = torch.randint(self.width, (1,))
                v = torch.randint(2, (2,))
                # (vi, vj) is one of the four axis-aligned unit steps
                vi = v[0] * (v[1] * 2 - 1)
                vj = (1 - v[0]) * (v[1] * 2 - 1)
                while True:
                    if i < 0 or i >= self.height or j < 0 or j >= self.width:
                        break
                    # Number of conductor cells among the four neighbors
                    o = 0
                    if i > 0:
                        o += (result[n, 0, i - 1, j] == self.token_conductor).long()
                    if i < self.height - 1:
                        o += (result[n, 0, i + 1, j] == self.token_conductor).long()
                    if j > 0:
                        o += (result[n, 0, i, j - 1] == self.token_conductor).long()
                    if j < self.width - 1:
                        o += (result[n, 0, i, j + 1] == self.token_conductor).long()
                    if o > 1:
                        break
                    result[n, 0, i, j] = self.token_conductor
                    i += vi
                    j += vj
                if (
                    result[n, 0] == self.token_conductor
                ).long().sum() > self.width and torch.rand(1) < 0.5:
                    break

            # Try to place one electron (a head followed by a tail) on two
            # adjacent conductor cells
            for _ in range(self.height * self.width):
                i = torch.randint(self.height, (1,))
                j = torch.randint(self.width, (1,))
                v = torch.randint(2, (2,))
                vi = v[0] * (v[1] * 2 - 1)
                vj = (1 - v[0]) * (v[1] * 2 - 1)
                if (
                    i + vi >= 0
                    and i + vi < self.height
                    and j + vj >= 0
                    and j + vj < self.width
                    and result[n, 0, i, j] == self.token_conductor
                    and result[n, 0, i + vi, j + vj] == self.token_conductor
                ):
                    result[n, 0, i, j] = self.token_head
                    result[n, 0, i + vi, j + vj] = self.token_tail
                    break

        # 3x3 kernel of ones, used to count the heads around each cell
        weight = torch.full((1, 1, 3, 3), 1.0)

        # Replace about 1% of the cells of the first frame with a random
        # cell token
        mask = (torch.rand(result[:, 0].size()) < 0.01).long()
        rand = torch.randint(4, mask.size())
        result[:, 0] = mask * rand + (1 - mask) * result[:, 0]

        # empty->empty
        # head->tail
        # tail->conductor
        # conductor->head if 1 or 2 head in the neighborhood, or remains conductor

        nb_heads = (result[:, 0] == self.token_head).flatten(1).long().sum(dim=1)
        valid = nb_heads > 0

        for l in range(nb_frames - 1):
            nb_head_neighbors = (
                F.conv2d(
                    input=(result[:, l] == self.token_head).float()[:, None, :, :],
                    weight=weight,
                    padding=1,
                )
                .long()
                .squeeze(1)
            )
            mask_1_or_2_heads = (nb_head_neighbors == 1).long() + (
                nb_head_neighbors == 2
            ).long()
            result[:, l + 1] = (
                (result[:, l] == self.token_empty).long() * self.token_empty
                + (result[:, l] == self.token_head).long() * self.token_tail
                + (result[:, l] == self.token_tail).long() * self.token_conductor
                + (result[:, l] == self.token_conductor).long()
                * (
                    mask_1_or_2_heads * self.token_head
                    + (1 - mask_1_or_2_heads) * self.token_conductor
                )
            )
            pred_nb_heads = nb_heads
            nb_heads = (
                (result[:, l + 1] == self.token_head).flatten(1).long().sum(dim=1)
            )
            # Reject the worlds in which an electron disappears
            valid = torch.logical_and(valid, (nb_heads >= pred_nb_heads))

        result = result[valid]

        # Keep one frame every self.speed steps
        result = result[
            :, torch.arange(self.nb_iterations, device=result.device) * self.speed
        ]

        # Keep the worlds whose last frame contains at least one head
        i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0
        result = result[i]

        if result.size(0) < nb:
            result = torch.cat(
                [result, self.generate_frame_sequences(nb - result.size(0))], dim=0
            )

        return result[:nb]

    def generate_token_sequences(self, nb):
        """Return `nb` token sequences built from generated frame sequences.

        Each sequence is the concatenation of the flattened frames, either
        in order with the forward token between them or, with probability
        1/2, in reverse order with the backward token between them.
        """
        frame_sequences = self.generate_frame_sequences(nb)

        result = []

        for frame_sequence in frame_sequences:
            a = []
            if torch.rand(1) < 0.5:
                for frame in frame_sequence:
                    if len(a) > 0:
                        a.append(torch.tensor([self.token_forward]))
                    a.append(frame.flatten())
            else:
                for frame in reversed(frame_sequence):
                    if len(a) > 0:
                        a.append(torch.tensor([self.token_backward]))
                    a.append(frame.flatten())

            result.append(torch.cat(a, dim=0)[None, :])

        return torch.cat(result, dim=0)

    ######################################################################

    def frame2img(self, x, scale=15):
        """Render frames as images, one `scale` x `scale` square per cell.

        `x` is reshaped to (-1, self.height, self.width). Cells whose value
        is not a valid index in self.colors are painted with color 0 and
        struck through with a black cross.
        """
        x = x.reshape(-1, self.height, self.width)
        # Valid cell tokens are the indices of self.colors
        m = torch.logical_and(x >= 0, x < self.colors.size(0)).long()

        x = self.colors[x * m].permute(0, 3, 1, 2)
        s = x.shape
        x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
        x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)

        # Black grid lines, then drop the outermost top / left line
        x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
        x[:, :, torch.arange(0, x.size(2), scale), :] = 0
        x = x[:, :, 1:, 1:]

        for n in range(m.size(0)):
            for i in range(m.size(1)):
                for j in range(m.size(2)):
                    if m[n, i, j] == 0:
                        for k in range(2, scale - 2):
                            for l in [0, 1]:
                                x[n, :, i * scale + k, j * scale + k - l] = 0
                                x[
                                    n, :, i * scale + scale - 1 - k, j * scale + k - l
                                ] = 0

        return x

    def seq2img(self, seq, scale=15):
        """Render token sequences as rows of frames separated by markers.

        Between two consecutive frames, a ">" is drawn for the forward
        token, a "<" for the backward token and a cross for anything else.
        """
        images = [
            self.frame2img(
                seq[:, : self.height * self.width].reshape(-1, self.height, self.width),
                scale,
            )
        ]

        separator = torch.full((seq.size(0), 3, self.height * scale - 1, 1), 0)

        t = self.height * self.width

        while t < seq.size(1):
            direction_tokens = seq[:, t]
            t += 1

            # Background filled with color 0
            direction_images = self.colors[
                torch.full(
                    (direction_tokens.size(0), self.height * scale - 1, scale), 0
                )
            ].permute(0, 3, 1, 2)

            # First pixel row of the marker, which is vertically centered
            top = (self.height * scale) // 2 - scale // 2

            for n in range(direction_tokens.size(0)):
                if direction_tokens[n] == self.token_forward:
                    for k in range(scale):
                        for l in [0, 1]:
                            direction_images[
                                n,
                                :,
                                top + k - l,
                                3 + scale // 2 - abs(k - scale // 2),
                            ] = 0
                elif direction_tokens[n] == self.token_backward:
                    for k in range(scale):
                        for l in [0, 1]:
                            direction_images[
                                n,
                                :,
                                top + k - l,
                                3 + abs(k - scale // 2),
                            ] = 0
                else:
                    for k in range(2, scale - 2):
                        for l in [0, 1]:
                            direction_images[n, :, top + k - l, k] = 0
                            direction_images[n, :, top + k - l, scale - 1 - k] = 0

            images += [
                separator,
                direction_images,
                separator,
                self.frame2img(
                    seq[:, t : t + self.height * self.width].reshape(
                        -1, self.height, self.width
                    ),
                    scale,
                ),
            ]

            t += self.height * self.width

        return torch.cat(images, dim=3)

    def seq2str(self, seq):
        """Return one string per sequence, mapping tokens with token2char."""
        return ["".join([self.token2char[v] for v in s]) for s in seq]

    def save_image(self, input, result_dir, filename):
        """Render `input` with seq2img and save it as result_dir/filename."""
        img = self.seq2img(input.to("cpu"))
        image_name = os.path.join(result_dir, filename)
        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)

    def save_quizzes(self, input, result_dir, filename_prefix):
        """Save the quizzes as the image result_dir/<filename_prefix>.png."""
        self.save_image(input, result_dir, filename_prefix + ".png")


######################################################################

if __name__ == "__main__":
    import time

    wireworld = Wireworld(height=8, width=10, nb_iterations=5, speed=1)

    start_time = time.perf_counter()
    frame_sequences = wireworld.generate_frame_sequences(nb=96)
    delay = time.perf_counter() - start_time
    print(f"{frame_sequences.size(0)/delay:02f} seq/s")

    # One picture per time step, showing that frame of every sequence
    for t in range(frame_sequences.size(1)):
        img = wireworld.frame2img(frame_sequences[:, t])
        torchvision.utils.save_image(
            img.float() / 255.0,
            f"/tmp/frame_{t:03d}.png",
            nrow=8,
            padding=6,
            pad_value=0,
        )

    wireworld = Wireworld(height=8, width=10, nb_iterations=2, speed=5)
    token_sequences = wireworld.generate_token_sequences(32)
    wireworld.save_quizzes(token_sequences, "/tmp", "seq")