--- /dev/null
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+import math
+
+import torch
+
+from torch import nn
+from torch.nn import functional as F
+
+# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+
+######################################################################
+
+
class VaswaniPositionalEncoding(nn.Module):
    """Add a sinusoidal positional encoding to a sequence of embeddings.

    The input is indexed as ``x[n, t, j]`` (dimension 1 is the position,
    dimension 2 the channel). Even channels receive a sine, odd channels
    the same sine shifted by pi/2; the wavelength grows geometrically with
    the channel index, with base ``len_max``.
    """

    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    def forward(self, x):
        nb_positions, dim = x.size(1), x.size(2)
        position = torch.arange(nb_positions, dtype=x.dtype, device=x.device)[:, None]
        channel = torch.arange(dim, dtype=x.dtype, device=x.device)[None, :]
        # 0 for even channels, 1 for odd ones (the modulo works on floats)
        parity = channel % 2
        wavelength = self.len_max ** ((channel - parity) / dim)
        # A pi/2 phase shift on the odd channels
        phase = math.pi / 2 * parity
        return x + torch.sin(position / wavelength + phase)
+
+
+######################################################################
+
+
class WithResidual(nn.Module):
    """Wrap one or several modules with a skip connection: ``x + f(x)``.

    A single module is used as is; several modules are chained in an
    ``nn.Sequential``.
    """

    def __init__(self, *f):
        super().__init__()
        if len(f) == 1:
            self.f = f[0]
        else:
            self.f = nn.Sequential(*f)

    def forward(self, x):
        branch = self.f(x)
        return x + branch
+
+
+######################################################################
+
+
def vanilla_attention(q, k, v):
    """Scaled dot-product attention.

    ``q`` is indexed (n, h, t, d), ``k`` (n, h, s, d) and ``v`` (n, h, s, d');
    the result is indexed (n, h, t, d'). Scores are divided by
    ``sqrt(q.size(3))`` and normalized with a softmax over the key index.
    """
    scale = math.sqrt(q.size(3))
    scores = torch.einsum("nhtd,nhsd->nhts", q, k) / scale
    weights = scores.softmax(dim=3)
    return torch.einsum("nhts,nhsd->nhtd", weights, v)
+
+
+######################################################################
+
+
class MHAttention(nn.Module):
    """Multi-head attention with explicit per-head projection tensors.

    The four projections ``w_q``, ``w_k``, ``w_v`` and ``w_o`` are indexed
    (head, head_dim, dim_model) and initialized with Gaussian values scaled
    by ``1 / sqrt(dim_model)``. The attention operator itself is pluggable
    through ``attention``, which receives the projected (q, k, v).

    NOTE(review): ``attention_dropout`` is stored but never used in
    ``forward``.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads=1,
        attention=vanilla_attention,
        attention_dropout=0.0,
    ):
        super().__init__()

        self.attention = attention
        self.attention_dropout = attention_dropout

        # Same creation order as the attribute names, so that the random
        # draws are consumed in the order w_q, w_k, w_v, w_o
        for name, dim_head in (
            ("w_q", dim_qk),
            ("w_k", dim_qk),
            ("w_v", dim_v),
            ("w_o", dim_v),
        ):
            weight = torch.randn(nb_heads, dim_head, dim_model) / math.sqrt(dim_model)
            setattr(self, name, nn.Parameter(weight))

    def forward(self, x_q, x_kv=None):
        # Self-attention when no separate key/value sequence is given
        source = x_q if x_kv is None else x_kv

        queries = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q)
        keys = torch.einsum("nsc,hdc->nhsd", source, self.w_k)
        values = torch.einsum("nsc,hdc->nhsd", source, self.w_v)
        per_head = self.attention(queries, keys, values)

        # Project every head back to the model dimension and sum over heads
        return torch.einsum("nhtd,hdc->ntc", per_head, self.w_o)
+
+
+######################################################################
+
+
class AttentionAE(nn.Module):
    """Transformer trunk mapping a token sequence to per-position logits.

    The model embeds the tokens (the embedding table has
    ``2 * vocabulary_size`` rows), adds a positional encoding, applies
    ``nb_blocks`` pre-norm residual blocks (attention then MLP) and reads
    out ``vocabulary_size`` logits per position.
    """

    def __init__(
        self,
        vocabulary_size,
        dim_model,
        dim_keys,
        dim_hidden,
        nb_heads,
        nb_blocks,
        dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        assert dim_model % nb_heads == 0

        self.embedding = nn.Sequential(
            nn.Embedding(2 * vocabulary_size, dim_model),
            nn.Dropout(dropout),
        )

        self.positional_encoding = VaswaniPositionalEncoding(len_max)

        layers = []

        for _ in range(nb_blocks):
            attention_block = WithResidual(
                nn.LayerNorm((dim_model,)),
                MHAttention(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    attention=vanilla_attention,
                    attention_dropout=dropout,
                ),
            )
            mlp_block = WithResidual(
                nn.LayerNorm((dim_model,)),
                nn.Linear(in_features=dim_model, out_features=dim_hidden),
                nn.ReLU(),
                nn.Linear(in_features=dim_hidden, out_features=dim_model),
                nn.Dropout(dropout),
            )
            layers.append(attention_block)
            layers.append(mlp_block)

        self.trunk = nn.Sequential(*layers)

        self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size)

        # Custom initialization of the embeddings and of the layer norms
        with torch.no_grad():
            for module in self.modules():
                if isinstance(module, nn.Embedding):
                    module.weight.normal_(mean=0, std=2e-2)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.zero_()
                    module.weight.fill_(1.0)

    def forward(self, x):
        embedded = self.positional_encoding(self.embedding(x))
        return self.readout(self.trunk(embedded))
+
+
+######################################################################
+
+
class WithMaskedResidual(nn.Module):
    """Residual wrapper whose skip connection is gated by a mask.

    Computes ``masker(x) * x + f(x)``, where ``f`` is the single wrapped
    module, or an ``nn.Sequential`` of the wrapped modules when several
    are given.

    The mask is recomputed from the input at every call and kept in
    ``self.mask``. Caching the mask of the first call forever, as was
    done before, breaks as soon as the sequence length or the device of
    the input changes.
    """

    def __init__(self, masker, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
        self.masker = masker
        self.mask = None

    def forward(self, x):
        # Always derive the mask from the current input so that it matches
        # its length and device
        self.mask = self.masker(x)
        return self.mask * x + self.f(x)
+
+
+######################################################################
+
+
class FunctionalAttentionAE(nn.Module):
    """Attention auto-encoder with extra "work" tokens in front of the input.

    ``nb_work_tokens`` zero embeddings are prepended to the embedded
    sequence before the trunk and removed before the readout. In the
    trunk, the residual connection is dropped at the work-token positions,
    and the attention forbids the two halves of the original sequence from
    attending directly to each other.
    """

    def __init__(
        self,
        vocabulary_size,
        dim_model,
        dim_keys,
        dim_hidden,
        nb_heads,
        nb_blocks,
        nb_work_tokens=100,
        dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        assert dim_model % nb_heads == 0

        self.nb_work_tokens = nb_work_tokens

        self.embedding = nn.Sequential(
            nn.Embedding(2 * vocabulary_size, dim_model),
            nn.Dropout(dropout),
        )

        self.positional_encoding = VaswaniPositionalEncoding(len_max)

        def no_peek_attention(q, k, v):
            scores = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
            start = self.nb_work_tokens
            half = (q.size(2) - start) // 2
            first = slice(start, start + half)
            second = slice(start + half, start + 2 * half)
            # Neither half of the sequence may look at the other one
            scores[:, :, second, first] = float("-inf")
            scores[:, :, first, second] = float("-inf")
            weights = scores.softmax(dim=3)
            return torch.einsum("nhts,nhsd->nhtd", weights, v)

        def masker(x):
            # False on the work tokens, True on the actual sequence
            keep = torch.arange(x.size(1), device=x.device) >= self.nb_work_tokens
            return keep[None, :, None]

        layers = []

        for _ in range(nb_blocks):
            attention_block = WithMaskedResidual(
                masker,
                nn.LayerNorm((dim_model,)),
                MHAttention(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    attention=no_peek_attention,
                    attention_dropout=dropout,
                ),
            )
            mlp_block = WithMaskedResidual(
                masker,
                nn.LayerNorm((dim_model,)),
                nn.Linear(in_features=dim_model, out_features=dim_hidden),
                nn.ReLU(),
                nn.Linear(in_features=dim_hidden, out_features=dim_model),
                nn.Dropout(dropout),
            )
            layers.append(attention_block)
            layers.append(mlp_block)

        self.trunk = nn.Sequential(*layers)

        self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size)

        # Custom initialization of the embeddings and of the layer norms
        with torch.no_grad():
            for module in self.modules():
                if isinstance(module, nn.Embedding):
                    module.weight.normal_(mean=0, std=2e-2)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.zero_()
                    module.weight.fill_(1.0)

    def forward(self, x):
        nb_work = self.nb_work_tokens
        x = self.embedding(x)
        # Prepend the work tokens along the sequence dimension
        x = F.pad(x, (0, 0, nb_work, 0))
        x = self.positional_encoding(x)
        x = self.trunk(x)
        # A negative padding crops the work tokens away
        x = F.pad(x, (0, 0, -nb_work, 0))
        return self.readout(x)
+
+
+######################################################################
+
+
if __name__ == "__main__":
    # Smoke test: one forward pass in training mode, one in eval mode
    vocabulary_size, batch_size, seq_len = 100, 10, 50

    model = FunctionalAttentionAE(
        vocabulary_size=vocabulary_size,
        dim_model=16,
        dim_keys=64,
        dim_hidden=32,
        nb_heads=4,
        nb_work_tokens=10,
        nb_blocks=4,
        dropout=0.1,
    )

    x = torch.randint(vocabulary_size, (batch_size, seq_len))
    y = model(x)

    with torch.no_grad():
        model.eval()
        x = torch.randint(vocabulary_size, (batch_size, seq_len))
        y = model(x)

    print(y)
# Written by Francois Fleuret <francois@fleuret.org>
-import math, sys, tqdm, os, warnings
+import math, sys, tqdm, os, warnings, cairo, re
import torch, torchvision
######################################################################
+
def text_img(height, width, text):
    """Render ``text`` in black on a white background with cairo.

    Each line of ``text`` (split on newlines) is drawn with a 16pt
    "courier" font. The vertical step between lines is 1.5 times the
    extent height cairo reports for the line being drawn.

    Returns a uint8 tensor of shape (1, 3, height, width).
    """
    pixel_map = torch.full((height, width, 4), 255, dtype=torch.uint8)

    # The cairo surface draws directly into the memory of pixel_map
    surface = cairo.ImageSurface.create_for_data(
        pixel_map.numpy(), cairo.FORMAT_ARGB32, pixel_map.size(1), pixel_map.size(0)
    )

    ctx = cairo.Context(surface)
    ctx.set_source_rgb(0, 0, 0)
    ctx.set_font_size(16)
    ctx.select_font_face("courier", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL)

    x = y = None
    for line in text.split("\n"):
        # Use dedicated names: the extents must not overwrite the
        # `height` / `width` parameters of the function
        _, _, _, line_height, _, _ = ctx.text_extents(line)
        if y is None:
            # The margins are set from the extents of the first line
            y = line_height * 1.5
            x = line_height * 0.5

        ctx.move_to(x, y)
        ctx.show_text(line)
        y += line_height * 1.5

    ctx.stroke()

    # Make sure all the pending drawing is written to the buffer before
    # reading it back through pixel_map
    surface.flush()

    return pixel_map.permute(2, 0, 1)[None, :3].contiguous()
+
+
+######################################################################
+
import problem
def grow_islands(nb, height, width, nb_seeds, nb_iterations):
    """Sample ``nb`` random grids of islands and label the islands.

    Cells are switched on one at a time in each grid: ``nb_seeds`` times
    at a random free cell that has no lit cell among its eight neighbors,
    then ``nb_iterations`` times at a random free cell accepted by the
    corner filters below. The lit cells are then grouped by propagating
    the maximum cell index inside each connected set (3x3 neighborhood),
    and the groups are renumbered consecutively.

    Returns a long tensor of shape (nb, height, width), where 0 is the
    background.
    """
    # Filter 0 counts the lit cells among the eight neighbors, filters 1-4
    # each look at one corner of the 3x3 neighborhood
    w = torch.tensor(
        [
            [[1.0, 1.0, 1.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]],
            [[-1.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
            [[0.0, 1.0, -1.0], [0.0, 0.0, 1.0], [0.0, 0.0, 0.0]],
            [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, -1.0]],
            [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [-1.0, 1.0, 0.0]],
        ]
    )[:, None]

    occupancy = torch.zeros(nb, height, width)

    def light_one_cell(criterion):
        # Channel 0 is the neighbor count, channel 1 the minimum response
        # of the four corner filters
        responses = F.conv2d(occupancy[:, None, :, :], w, padding=1)
        responses = torch.cat(
            [responses[:, :1], responses[:, 1:].min(dim=1, keepdim=True).values],
            dim=1,
        )
        candidates = (criterion(responses) & (occupancy == 0)).long()
        # 1 for the grids which have at least one candidate cell
        has_candidate = (candidates.flatten(1).max(dim=1).values > 0).long()[:, None]
        # Pick one candidate uniformly at random in each grid
        scores = (candidates * torch.rand(candidates.size())).flatten(1)
        chosen = F.one_hot(scores.argmax(dim=1), num_classes=scores.size(1))
        # In-place update through a flattened view of the grids
        flat = occupancy.flatten(1)
        flat += chosen * has_candidate

    for _ in range(nb_seeds):
        light_one_cell(lambda r: r[:, 0] == 0)

    for _ in range(nb_iterations):
        light_one_cell(lambda r: r[:, 1] >= 0)

    # Give every lit cell a unique positive index, then spread the maximum
    # index over each connected set until nothing changes
    cell_index = torch.arange(height * width) + 1
    labels = occupancy * cell_index.reshape(1, height, width)

    while True:
        previous = labels
        labels = F.max_pool2d(labels, 3, 1, 1) * occupancy
        if labels.equal(previous):
            break

    labels = labels.long()

    # Renumber the labels present in each grid as 0, 1, 2, ...
    present = F.one_hot(labels.flatten(1)).max(dim=1).values
    rank = present.cumsum(dim=1) - present
    grid_index = torch.arange(labels.size(0))[:, None, None].expand_as(labels)

    return rank[grid_index, labels]
+
+
class Grids(problem.Problem):
    # Rendering parameters read by the image-generation methods: gray
    # level and thickness of the grid lines, gray level of the background,
    # and whether a dot is drawn at the center of the background cells.
    #
    # Alternative settings that were tried:
    #   grid_gray = 64, thickness = 1, background_gray = 255, dots = False
    #   grid_gray = 192, thickness = 0, background_gray = 255, dots = True
    grid_gray = 240
    thickness = 0
    background_gray = 240
    dots = False

    # (name, RGB) pairs; the position in the list is the token value of
    # the color, and the first entry is the background
    named_colors = [
        ("white", [background_gray] * 3),
        ("red", [255, 0, 0]),
        ("green", [0, 160, 0]),
        ("blue", [0, 0, 255]),
        ("yellow", [255, 224, 0]),
        ("cyan", [0, 255, 255]),
        ("violet", [224, 128, 255]),
        ("lightgreen", [160, 255, 160]),
        ("brown", [165, 42, 42]),
        ("lightblue", [192, 192, 255]),
        ("gray", [128, 128, 128]),
    ]
- def __init__(self, device=torch.device("cpu")):
def pure_noise(self, nb, device):
    """Return ``nb`` sequences of uniformly random color tokens.

    The result is an integer tensor of shape (nb, 4 * height * width) on
    ``device``, with values in [0, nb_colors).
    """
    # A grid has height * width cells (not height * height), and a
    # sequence is made of four grids
    nb_cells = self.height * self.width
    return torch.randint(self.nb_colors, (nb, 4 * nb_cells), device=device)
+
def trivial(self, quizzes):
    """Flag the quizzes in which a grid is identical to its image.

    ``quizzes`` must be in the order (A, f_A, B, f_B), each grid being
    preceded by one label token. Returns a boolean tensor with one entry
    per quiz, True when A == f_A or B == f_B.
    """
    assert self.check_order(quizzes, quad_order=("A", "f_A", "B", "f_B"))
    nb_cells = self.height * self.width
    # Split in four (label + grid) chunks and drop the label tokens
    grids = quizzes.reshape(quizzes.size(0), 4, nb_cells + 1)[:, :, 1:]
    A, f_A, B, f_B = grids.unbind(dim=1)
    same_A = (A == f_A).all(dim=1)
    same_B = (B == f_B).all(dim=1)
    return same_A | same_B
+
def text2quiz(self, t):
    """Parse a textual description of quizzes into token sequences.

    Comments (from ``#`` to the end of the line) are removed. Quizzes are
    separated by an empty line or a ``;``. Each quiz is ``height`` lines
    of four groups of ``width`` characters, one group per grid (A, f_A,
    B, f_B), each character standing for a color.

    Returns a tensor with one row per quiz, made of the four grids, each
    preceded by its label token.
    """
    # Every color needs its own character. "l" used to be listed for both
    # lightgreen and lightblue, so lightgreen could never be written; "l"
    # keeps the meaning it effectively had (lightblue) and lightgreen
    # gets "L".
    chr2col = [
        (".", "white"),
        ("r", "red"),
        ("g", "green"),
        ("b", "blue"),
        ("y", "yellow"),
        ("c", "cyan"),
        ("v", "violet"),
        ("L", "lightgreen"),
        ("o", "brown"),
        ("l", "lightblue"),
        ("a", "gray"),
    ]

    col2tok = {name: tok for tok, (name, _) in enumerate(self.named_colors)}
    chr2tok = {char: col2tok[col] for char, col in chr2col}

    text = re.sub(r"#.*\n", "", t).strip()
    blocks = text.replace("\n\n", ";").split(";")

    labels = torch.tensor(
        [
            [self.token_A],
            [self.token_f_A],
            [self.token_B],
            [self.token_f_B],
        ]
    )

    result = []

    for block in blocks:
        # Drop all the whitespace, keeping only the color characters
        cells = "".join(block.split())
        grids = torch.tensor([chr2tok[c] for c in cells])
        # The text is laid out as (line, grid, column): bring the grid
        # index first and flatten each grid
        grids = (
            grids.reshape(self.height, 4, self.width).permute(1, 0, 2).flatten(1)
        )
        quiz = torch.cat([labels, grids], dim=1)
        result.append(quiz.flatten()[None, :])

    return torch.cat(result, dim=0)
+
def indices_select(self, quizzes, quad_order=("A", "f_A", "B", "f_B")):
    """Return a boolean mask of the quizzes whose grids are in ``quad_order``.

    Each quiz is read as four chunks of one label token followed by
    height * width cells; the four label tokens are compared with the
    tokens ``self.l2tok`` associates with the names in ``quad_order``.
    """
    nb_cells = self.height * self.width
    labels = quizzes.reshape(quizzes.size(0), 4, nb_cells + 1)[:, :, 0]
    selected = labels[:, 0] == self.l2tok[quad_order[0]]
    for k in range(1, 4):
        selected = selected & (labels[:, k] == self.l2tok[quad_order[k]])
    return selected
+
def __init__(
    self,
    max_nb_cached_chunks=None,
    chunk_size=None,
    nb_threads=-1,
    tasks=None,
):
    """Set up the palette, the grid geometry and the list of tasks.

    ``tasks`` is either None, to use the default list of tasks, or a
    comma-separated string of task names without their ``task_`` prefix.
    The three other arguments are forwarded to the parent constructor.
    """
    self.colors = torch.tensor([rgb for _, rgb in self.named_colors])
    self.nb_colors = len(self.colors)

    self.nb_rec_max = 5
    # No pre-sampled rectangle-world state is available yet
    self.rfree = torch.tensor([])

    self.height = 10
    self.width = 10
    self.seq_len = 4 * self.height * self.width

    self.cache_rec_coo = {}

    default_task_names = (
        # fundamental ones
        "replace_color",
        "translate",
        "grow",
        "frame",
        # intermediate ones
        "half_fill",
        "detect",
        "scale",
        "symbols",
        "corners",
        "contact",
        "path",
        "fill",
        # hard ones
        "isometry",
        "trajectory",
        "bounce",
        # "count" is left out (not reversible), and so is "islands"
        # (too messy)
    )
    default_tasks = [getattr(self, "task_" + name) for name in default_task_names]

    if tasks is None:
        self.all_tasks = default_tasks
    else:
        self.all_tasks = [getattr(self, "task_" + name) for name in tasks.split(",")]

    super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
- x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
- x[:, :, torch.arange(0, x.size(2), scale), :] = 0
- x = x[:, :, 1:, 1:]
+ ######################################################################
+
def vocabulary_size(self):
    """Return the number of token values, which is the number of colors."""
    # A former hack returned nb_colors + 4 to keep the vocabulary size
    # unchanged; it is no longer used.
    nb_tokens = self.nb_colors
    return nb_tokens
+
def grid2img(self, x, scale=15, grids=True):
    """Convert grids of tokens into RGB images.

    ``x`` is an integer tensor indexed (n, i, j). Every cell becomes a
    ``scale`` x ``scale`` square of the color of its token. Tokens outside
    [0, nb_colors) are drawn with the first color; those greater than or
    equal to ``nb_colors`` additionally get an inner square of color
    ``colors[token - nb_colors]``. When ``grids`` is true, grid lines of
    ``self.thickness`` pixels are drawn, and center dots when
    ``self.dots`` is set.

    Returns a tensor indexed (n, channel, row, column), with the first
    row and the first column of pixels removed.
    """
    in_palette = torch.logical_and(x >= 0, x < self.nb_colors).long()
    y = self.colors[x * in_palette].permute(0, 3, 1, 2)
    nb, nb_channels, height, width = y.shape
    # Scale every cell up to a scale x scale block of pixels
    y = y[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
    y = y.reshape(nb, nb_channels, height * scale, width * scale)

    if grids:
        for t in range(self.thickness):
            y[:, :, :, torch.arange(t, y.size(3), scale)] = self.grid_gray
            y[:, :, torch.arange(t, y.size(2), scale), :] = self.grid_gray
        if self.dots:
            # View of the 3x3 pixels at the center of every cell
            z = y.reshape(
                y.size(0),
                y.size(1),
                y.size(2) // scale,
                scale,
                y.size(3) // scale,
                scale,
            )
            z = z[
                :,
                :,
                :,
                scale // 2 - 1 : scale // 2 + 2,
                :,
                scale // 2 - 1 : scale // 2 + 2,
            ]
            # Only the pixels which are of the background gray on all the
            # channels are turned into a dot
            zz = (z == self.background_gray).all(dim=1, keepdim=True)
            z[...] = zz * self.grid_gray + (~zz) * z

    # Visit only the cells that carry an out-of-palette token, instead of
    # testing every single cell in a Python loop
    for n, i, j in (x >= self.nb_colors).nonzero().tolist():
        c = self.colors[x[n, i, j] - self.nb_colors][:, None, None]
        y[
            n,
            :,
            i * scale + 3 : i * scale + scale - 2,
            j * scale + 3 : j * scale + scale - 2,
        ] = c

    y = y[:, :, 1:, 1:]

    return y
+
def add_frame(self, img, colors, thickness):
    """Surround images with a frame of ``thickness`` pixels.

    ``img`` is indexed (n, channel, row, column) and ``colors`` is indexed
    (n or 1, channel). With a non-positive ``thickness`` the input tensor
    itself is returned.
    """
    if thickness <= 0:
        return img

    nb, nb_channels, height, width = img.size()
    framed = img.new_empty(
        (nb, nb_channels, height + 2 * thickness, width + 2 * thickness)
    )
    # Paint everything with the frame color, then paste the image inside
    framed[...] = colors[:, :, None, None]
    framed[
        :, :, thickness : thickness + height, thickness : thickness + width
    ] = img
    return framed
- def save_image(
+ def save_quizzes_as_image(
self,
result_dir,
filename,
- prompts,
- answers,
- predicted_prompts=None,
- predicted_answers=None,
+ quizzes,
+ predicted_parts=None,
+ correct_parts=None,
+ comments=None,
+ comment_height=48,
nrow=4,
- margin=8,
+ grids=True,
+ margin=12,
+ delta=False,
+ delta_highlight=False,
):
- S = self.height * self.width
- As = prompts[:, 0 * (S + 1) : 0 * (S + 1) + S].view(-1, self.height, self.width)
- f_As = prompts[:, 1 * (S + 1) : 1 * (S + 1) + S].view(
- -1, self.height, self.width
- )
- Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S].view(-1, self.height, self.width)
- prompts = torch.cat([As, f_As, Bs], dim=2)
- answers = answers.reshape(answers.size(0), self.height, self.width)
-
- if predicted_prompts is None:
- predicted_prompts = 255
+ quizzes = quizzes.to("cpu")
- if predicted_answers is None:
- predicted_answers = 255
-
- def add_frame(x, c, margin, bottom=False):
- if bottom:
- h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
- else:
- h, w, di, dj = (
- x.size(2) + 2 * margin,
- x.size(3) + 2 * margin,
- margin,
- margin,
- )
-
- y = x.new_full((x.size(0), x.size(1), h, w), 0)
-
- if type(c) is int:
- y[...] = c
- else:
- c = c.long()[:, None]
- c = (
- (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long()))
- * torch.tensor([64, 64, 64], device=c.device)
- + (c == 1).long() * torch.tensor([0, 255, 0], device=c.device)
- + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device)
- + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device)
- )
- y[...] = c[:, :, None, None]
-
- y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
+ S = self.height * self.width
- return y
+ A, f_A, B, f_B = (
+ quizzes.reshape(quizzes.size(0), 4, S)
+ .reshape(quizzes.size(0), 4, self.height, self.width)
+ .permute(1, 0, 2, 3)
+ )
- img_prompts = torch.cat(
+ frame, white, gray, green, red = torch.tensor(
[
- add_frame(
- add_frame(self.frame2img(x), c=0, margin=1),
- c=predicted_prompts,
- margin=margin,
- )
- for x in prompts.to("cpu").split(split_size=self.width, dim=2)
+ [self.grid_gray, self.grid_gray, self.grid_gray],
+ [255, 255, 255],
+ [200, 200, 200],
+ [0, 255, 0],
+ [255, 0, 0],
],
- dim=3,
+ device=quizzes.device,
)
- h = img_prompts.size(2)
- img_answers = add_frame(
- add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1),
- c=predicted_answers,
- margin=margin,
- )
+ thickness = self.thickness
- separator_size = 2 * margin
+ if delta:
+ u = (A != f_A).long()
+ img_delta_A = self.add_frame(
+ self.grid2img(u, grids=grids), frame[None, :], thickness=thickness
+ )
+ img_delta_A = img_delta_A.min(dim=1, keepdim=True).values.expand_as(
+ img_delta_A
+ )
+ u = (B != f_B).long()
+ img_delta_B = self.add_frame(
+ self.grid2img(u, grids=grids), frame[None, :], thickness=thickness
+ )
+ img_delta_B = img_delta_B.min(dim=1, keepdim=True).values.expand_as(
+ img_delta_B
+ )
- separator = img_prompts.new_full(
- (
- img_prompts.size(0),
- img_prompts.size(1),
- img_prompts.size(2),
- separator_size,
- ),
- 255,
+ img_A = self.add_frame(
+ self.grid2img(A, grids=grids), frame[None, :], thickness=thickness
)
-
- marker = img_prompts.new_full(
- (
- img_prompts.size(0),
- img_prompts.size(1),
- img_prompts.size(2),
- separator_size,
- ),
- 255,
+ img_f_A = self.add_frame(
+ self.grid2img(f_A, grids=grids), frame[None, :], thickness=thickness
+ )
+ img_B = self.add_frame(
+ self.grid2img(B, grids=grids), frame[None, :], thickness=thickness
+ )
+ img_f_B = self.add_frame(
+ self.grid2img(f_B, grids=grids), frame[None, :], thickness=thickness
)
- # marker[:, :, 0] = 0
- # marker[:, :, h - 1] = 0
+ if delta_highlight:
+ q = (img_B == img_f_B).min(dim=1, keepdim=True).values.long()
+ img_f_B = q * (img_f_B // 4 + 192) + (1 - q) * img_f_B
+
+ # predicted_parts Nx4
+ # correct_parts Nx4
+
+ if predicted_parts is None:
+ colors = white[None, None, :].expand(-1, 4, -1)
+ else:
+ predicted_parts = predicted_parts.to("cpu")
+ if correct_parts is None:
+ colors = (
+ predicted_parts[:, :, None] * gray[None, None, :]
+ + (1 - predicted_parts[:, :, None]) * white[None, None, :]
+ )
+ else:
+ correct_parts = correct_parts.to("cpu")
+ colors = (
+ predicted_parts[:, :, None]
+ * (
+ (correct_parts[:, :, None] == 1).long() * green[None, None, :]
+ + (correct_parts[:, :, None] == 0).long() * gray[None, None, :]
+ + (correct_parts[:, :, None] == -1).long() * red[None, None, :]
+ )
+ + (1 - predicted_parts[:, :, None]) * white[None, None, :]
+ )
+
+ separation = 6
- for k in range(1, 2 * separator_size - 8):
- i = k - (separator_size - 4)
- j = separator_size - 5 - abs(i)
- marker[:, :, h // 2 - 1 + i, 2 + j] = 0
- marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
+ img_A = self.add_frame(img_A, colors[:, 0], thickness=separation)
+ img_f_A = self.add_frame(img_f_A, colors[:, 1], thickness=separation)
+ img_B = self.add_frame(img_B, colors[:, 2], thickness=separation)
+ img_f_B = self.add_frame(img_f_B, colors[:, 3], thickness=separation)
- img = torch.cat(
- [
- img_prompts,
- marker,
- img_answers,
- ],
- dim=3,
- )
+ img_A = self.add_frame(img_A, white[None, :], thickness=2)
+ img_f_A = self.add_frame(img_f_A, white[None, :], thickness=2)
+ img_B = self.add_frame(img_B, white[None, :], thickness=2)
+ img_f_B = self.add_frame(img_f_B, white[None, :], thickness=2)
+
+ if delta:
+ img_delta_A = self.add_frame(
+ img_delta_A, colors[:, 0], thickness=separation
+ )
+ img_delta_A = self.add_frame(img_delta_A, white[None, :], thickness=2)
+ img_delta_B = self.add_frame(
+ img_delta_B, colors[:, 0], thickness=separation
+ )
+ img_delta_B = self.add_frame(img_delta_B, white[None, :], thickness=2)
+ img = torch.cat(
+ [img_A, img_f_A, img_delta_A, img_B, img_f_B, img_delta_B], dim=3
+ )
+ else:
+ img = torch.cat([img_A, img_f_A, img_B, img_f_B], dim=3)
+
+ if comments is not None:
+ comment_img = [text_img(comment_height, img.size(3), t) for t in comments]
+ comment_img = torch.cat(comment_img, dim=0)
+ img = torch.cat([img, comment_img], dim=2)
image_name = os.path.join(result_dir, filename)
+
torchvision.utils.save_image(
img.float() / 255.0,
image_name,
######################################################################
- def nb_token_values(self):
- return len(self.colors)
-
# @torch.compile
- def rec_coo_(self, nb_rec, min_height=3, min_width=3):
- # @torch.compile
- def overlap(ia, ja, ib, jb):
- return (
- ia[1] >= ib[0] and ia[0] <= ib[1] and ja[1] >= jb[0] and ja[0] <= jb[1]
- )
def rec_coo(
    self,
    nb_rec,
    min_height=3,
    min_width=3,
    surface_max=None,
    prevent_overlap=False,
):
    """Sample ``nb_rec`` random rectangles in the grid.

    Returns a list of ``nb_rec`` tuples (i1, j1, i2, j2) of Python ints,
    the rectangle covering the rows i1:i2 and the columns j1:j2. Every
    rectangle is at least ``min_height`` x ``min_width`` and its area is
    at most ``surface_max`` (half the grid by default). With
    ``prevent_overlap`` the rectangles of a set are pairwise disjoint.

    Sets are sampled by large batches with rejection and cached per
    combination of arguments; each call pops one set from the cache.
    """
    if surface_max is None:
        surface_max = self.height * self.width // 2

    # prevent_overlap is part of the key, otherwise a request for disjoint
    # rectangles could be served from a cache of overlapping ones
    signature = (nb_rec, min_height, min_width, surface_max, prevent_overlap)

    cached = self.cache_rec_coo.get(signature)
    if cached:
        return cached.pop()

    N = 10000
    while True:
        i = torch.randint(self.height, (N * nb_rec, 2)).sort(dim=-1).values
        j = torch.randint(self.width, (N * nb_rec, 2)).sort(dim=-1).values
        # The second coordinate is exclusive
        i[:, 1] += 1
        j[:, 1] += 1
        big_enough = (
            (i[:, 1] >= i[:, 0] + min_height)
            & (j[:, 1] >= j[:, 0] + min_width)
            & ((i[:, 1] - i[:, 0]) * (j[:, 1] - j[:, 0]) <= surface_max)
        )

        i, j = i[big_enough], j[big_enough]

        # Keep a multiple of nb_rec rectangles to group them in sets
        n = i.size(0) - i.size(0) % nb_rec

        if n == 0:
            continue

        i = i[:n].reshape(n // nb_rec, nb_rec, -1)
        j = j[:n].reshape(n // nb_rec, nb_rec, -1)

        if prevent_overlap:
            # Cheap necessary condition: the total area fits in the grid
            can_fit = ((i[:, :, 1] - i[:, :, 0]) * (j[:, :, 1] - j[:, :, 0])).sum(
                dim=-1
            ) <= self.height * self.width
            i, j = i[can_fit], j[can_fit]

            # Two rectangles are disjoint if one is entirely above, below,
            # left or right of the other; check every pair of the set
            no_overlap = torch.full((i.size(0),), True)
            for a in range(nb_rec):
                for b in range(a + 1, nb_rec):
                    no_overlap &= (
                        (i[:, a, 0] >= i[:, b, 1])
                        | (i[:, a, 1] <= i[:, b, 0])
                        | (j[:, a, 0] >= j[:, b, 1])
                        | (j[:, a, 1] <= j[:, b, 0])
                    )
            i, j = i[no_overlap], j[no_overlap]

        if i.size(0) > 1:
            break

    # One bulk conversion to Python ints instead of one .item() per value
    coordinates = torch.stack(
        [i[:, :, 0], j[:, :, 0], i[:, :, 1], j[:, :, 1]], dim=-1
    ).tolist()

    self.cache_rec_coo[signature] = [
        [tuple(rec) for rec in recs] for recs in coordinates
    ]

    return self.cache_rec_coo[signature].pop()
- h = (
+ ######################################################################
+
def contact_matrices(self, rn, ri, rj, rz):
    """Compute which pairs of rectangles touch each other.

    ``ri`` and ``rj`` are indexed (sample, rectangle, 2) and hold the
    first and last (inclusive) row and column of every rectangle; ``rn``
    holds one value per sample. ``rz`` is not used.

    Returns a boolean tensor indexed (sample, a, b), True when b < a,
    a < rn[sample], and the rectangles a and b are adjacent along one
    axis while their ranges overlap along the other.
    """
    index = torch.arange(self.nb_rec_max)

    top_a, bottom_a = ri[:, :, None, 0], ri[:, :, None, 1]
    top_b, bottom_b = ri[:, None, :, 0], ri[:, None, :, 1]
    left_a, right_a = rj[:, :, None, 0], rj[:, :, None, 1]
    left_b, right_b = rj[:, None, :, 0], rj[:, None, :, 1]

    rows_adjacent = (top_a == bottom_b + 1) | (bottom_a + 1 == top_b)
    cols_adjacent = (left_a == right_b + 1) | (right_a + 1 == left_b)
    rows_overlap = (top_a <= bottom_b) & (bottom_a >= top_b)
    cols_overlap = (left_a <= right_b) & (right_a >= left_b)

    touching = (rows_adjacent & cols_overlap) | (cols_adjacent & rows_overlap)

    # Only the first rn rectangles count, and each pair is reported once
    active = index[None, :, None] < rn[:, None, None]
    ordered = index[None, None, :] < index[None, :, None]

    return touching & active & ordered
- i = torch.logical_and(
- v.sum(dim=-1) >= min_height, h.sum(dim=-1) >= min_width
def sample_rworld_states(self, N=1000):
    """Sample a pool of collision-free rectangle-world states.

    ``N`` candidate states are drawn, each with ``nb_rec_max`` rectangles
    (row range ``ri``, column range ``rj``, layer ``rz``, color ``rc``)
    of which the first ``rn`` are active. The candidates in which two
    active rectangles of the same layer intersect are discarded, and the
    sampling is repeated until at least one candidate remains.

    The remaining states are stored in ``self.rn``, ``self.ri``,
    ``self.rj``, ``self.rz`` and ``self.rc``. ``self.rcontact`` tells for
    each of them whether it has touching rectangles, and ``self.rfree``
    marks them all as available.
    """
    while True:
        ri = (
            torch.randint(self.height - 2, (N, self.nb_rec_max, 2))
            .sort(dim=2)
            .values
        )
        ri[:, :, 1] += 2
        rj = (
            torch.randint(self.width - 2, (N, self.nb_rec_max, 2))
            .sort(dim=2)
            .values
        )
        rj[:, :, 1] += 2
        rn = torch.randint(self.nb_rec_max - 1, (N,)) + 2
        rz = torch.randint(2, (N, self.nb_rec_max))
        rc = torch.randint(self.nb_colors - 1, (N, self.nb_rec_max)) + 1
        n = torch.arange(self.nb_rec_max)
        # Number of pairs of active rectangles of the same layer whose
        # row and column ranges both intersect
        nb_collisions = (
            (
                (ri[:, :, None, 0] <= ri[:, None, :, 1])
                & (ri[:, :, None, 1] >= ri[:, None, :, 0])
                & (rj[:, :, None, 0] <= rj[:, None, :, 1])
                & (rj[:, :, None, 1] >= rj[:, None, :, 0])
                & (rz[:, :, None] == rz[:, None, :])
                & (n[None, :, None] < rn[:, None, None])
                & (n[None, None, :] < n[None, :, None])
            )
            .long()
            .flatten(1)
            .sum(dim=1)
        )

        no_collision = nb_collisions == 0

        if no_collision.any():
            # Fraction of the candidates that are kept
            print(no_collision.long().sum() / N)
            self.rn = rn[no_collision]
            self.ri = ri[no_collision]
            self.rj = rj[no_collision]
            self.rz = rz[no_collision]
            self.rc = rc[no_collision]

            # Use the filtered states so that rcontact is aligned with
            # rn, ri, rj, rz and rc
            nb_contact = (
                self.contact_matrices(self.rn, self.ri, self.rj, self.rz)
                .long()
                .flatten(1)
                .sum(dim=1)
            )

            self.rcontact = nb_contact > 0
            self.rfree = torch.full((self.rn.size(0),), True)

            break
- av = torch.arange(v.size(2), device=self.device)[None, :]
- ah = torch.arange(h.size(2), device=self.device)[None, :]
def get_recworld_state(self):
    """Return one unused pre-sampled state as (rn, ri, rj, rz, rc).

    A new pool is sampled when no state is available. The state is
    picked at random among the free ones and marked as used.
    """
    if not self.rfree.any():
        self.sample_rworld_states()
    free = self.rfree.nonzero().flatten()
    k = free[torch.randint(free.size(0), (1,))].item()
    self.rfree[k] = False
    return self.rn[k], self.ri[k], self.rj[k], self.rz[k], self.rc[k]
- return [
- (i1.item(), j1.item(), i2.item() + 1, j2.item() + 1)
- for i1, j1, i2, j2 in zip(
- v.size(2) - (v[0] * (v.size(2) - av)).max(dim=-1).values,
- h.size(2) - (h[0] * (h.size(2) - ah)).max(dim=-1).values,
- (v[0] * av).max(dim=-1).values,
- (h[0] * ah).max(dim=-1).values,
- )
- ]
def draw_state(self, X, rn, ri, rj, rz, rc):
    """Paint the first ``rn`` rectangles of a state into the grid ``X``.

    The rectangles are drawn by increasing ``rz``, so the ones with a
    larger ``rz`` end up on top. ``ri`` and ``rj`` hold the first and
    last (inclusive) row and column of each rectangle, ``rc`` its color.
    """
    drawing_order = list(range(rn))
    drawing_order.sort(key=lambda n: rz[n].item())
    for n in drawing_order:
        rows = slice(ri[n, 0], ri[n, 1] + 1)
        cols = slice(rj[n, 0], rj[n, 1] + 1)
        X[rows, cols] = rc[n]
- # @torch.compile
- def rec_coo_(self, x, n, min_height=3, min_width=3):
- collision = x.new(x.size())
- while True:
- collision[...] = 0
- result = []
- for _ in range(n):
- while True:
- i1, i2 = torch.randint(x.size(0), (2,))
- if i1 + min_height <= i2:
- break
- while True:
- j1, j2 = torch.randint(x.size(1), (2,))
- if j1 + min_width <= j2:
- break
- collision[i1:i2, j1:j2] += 1
- if collision.max() > 1:
- break
- result.append((i1, j1, i2, j2))
- if collision.max() == 1:
- break
- return result
def task_recworld_immobile(self, A, f_A, B, f_B):
    """Fill each (X, f_X) pair from a rectangle-world state.

    The state is drawn in X, then drawn again in f_X after all the row
    coordinates have been incremented by one.
    """
    for grid, f_grid in ((A, f_A), (B, f_B)):
        rn, ri, rj, rz, rc = self.get_recworld_state()
        self.draw_state(grid, rn, ri, rj, rz, rc)
        # In-place shift of the row coordinates
        ri += 1
        self.draw_state(f_grid, rn, ri, rj, rz, rc)
######################################################################
# @torch.compile
def task_replace_color(self, A, f_A, B, f_B):
    """Draw three disjoint rectangles; f_X recolors the first one.

    The same colors are used for the (A, f_A) and (B, f_B) pairs: the
    first rectangle takes an extra fourth color in f_X, the two others
    keep theirs.
    """
    nb_rec = 3
    c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
    # Colors after the transformation: the first one is swapped for the
    # spare color
    c_after = c.clone()
    c_after[0] = c[-1]
    for X, f_X in ((A, f_A), (B, f_B)):
        r = self.rec_coo(nb_rec, prevent_overlap=True)
        for n in range(nb_rec):
            i1, j1, i2, j2 = r[n]
            X[i1:i2, j1:j2] = c[n]
            f_X[i1:i2, j1:j2] = c_after[n]
+ # @torch.compile
def task_symmetry(self, A, f_A, B, f_B):
    """Show one half of a grid in X and its mirrored completion in f_X.

    Rectangles are drawn in both grids; the upper half of X is erased,
    while the upper half of f_X is replaced by the mirror image of its
    lower half. A random vertical flip and a random transposition, the
    same for both pairs, are then applied.
    """
    a, b = torch.randint(2, (2,))
    nb_rec = 3
    c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
    half = self.height // 2
    for X, f_X in ((A, f_A), (B, f_B)):
        # Resample until every rectangle reaches below the middle row
        r = self.rec_coo(nb_rec, prevent_overlap=True)
        while not min(rec[2] for rec in r) > half + 1:
            r = self.rec_coo(nb_rec, prevent_overlap=True)
        for n in range(nb_rec):
            i1, j1, i2, j2 = r[n]
            X[i1:i2, j1:j2] = c[n]
            f_X[i1:i2, j1:j2] = c[n]
        X[:half] = 0
        f_X[:half] = f_X.flip([0])[:half]
        if a == 1:
            X[...] = X.flip((0,))
            f_X[...] = f_X.flip((0,))
        if b == 1:
            X[...] = X.clone().t()
            f_X[...] = f_X.clone().t()
+
# @torch.compile
def task_translate(self, A, f_A, B, f_B):
- di, dj = torch.randint(3, (2,)) - 1
+ while True:
+ di, dj = torch.randint(3, (2,)) - 1
+ if di.abs() + dj.abs() > 0:
+ break
+
nb_rec = 3
- c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+ c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
while True:
- r = self.rec_coo(nb_rec)
+ r = self.rec_coo(nb_rec, prevent_overlap=True)
i1, j1, i2, j2 = r[nb_rec - 1]
if (
i1 + di >= 0
def task_grow(self, A, f_A, B, f_B):
di, dj = torch.randint(2, (2,)) * 2 - 1
nb_rec = 3
- c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
- direction = torch.randint(2, (1,))
+ c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
+ direction = torch.randint(2, (1,)).item()
for X, f_X in [(A, f_A), (B, f_B)]:
while True:
- r = self.rec_coo(nb_rec)
+ r = self.rec_coo(nb_rec, prevent_overlap=True)
i1, j1, i2, j2 = r[nb_rec - 1]
if i1 + 3 < i2 and j1 + 3 < j2:
break
f_X[i1:i2, j1:j2] = c[n]
# @torch.compile
- def task_color_grow(self, A, f_A, B, f_B):
+ def task_half_fill(self, A, f_A, B, f_B):
di, dj = torch.randint(2, (2,)) * 2 - 1
nb_rec = 3
- c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1
- direction = torch.randint(4, (1,))
+ c = torch.randperm(self.nb_colors - 1)[: 2 * nb_rec] + 1
+ direction = torch.randint(4, (1,)).item()
for X, f_X in [(A, f_A), (B, f_B)]:
- r = self.rec_coo(nb_rec)
+ r = self.rec_coo(nb_rec, prevent_overlap=True)
for n in range(nb_rec):
i1, j1, i2, j2 = r[n]
X[i1:i2, j1:j2] = c[2 * n]
# @torch.compile
def task_frame(self, A, f_A, B, f_B):
nb_rec = 3
- c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+ c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
- r = self.rec_coo(nb_rec)
+ r = self.rec_coo(nb_rec, prevent_overlap=True)
for n in range(nb_rec):
i1, j1, i2, j2 = r[n]
X[i1:i2, j1:j2] = c[n]
- f_X[i1:i2, j1:j2] = c[n]
if n == nb_rec - 1:
- f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0
+ f_X[i1:i2, j1] = c[n]
+ f_X[i1:i2, j2 - 1] = c[n]
+ f_X[i1, j1:j2] = c[n]
+ f_X[i2 - 1, j1:j2] = c[n]
+ else:
+ f_X[i1:i2, j1:j2] = c[n]
# @torch.compile
def task_detect(self, A, f_A, B, f_B):
nb_rec = 3
- c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+ c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
- r = self.rec_coo(nb_rec)
+ r = self.rec_coo(nb_rec, prevent_overlap=True)
for n in range(nb_rec):
i1, j1, i2, j2 = r[n]
X[i1:i2, j1:j2] = c[n]
+ f_X[i1:i2, j1:j2] = c[n]
if n < nb_rec - 1:
- f_X[i1, j1] = c[-1]
+ for k in range(2):
+ f_X[i1 + k, j1] = c[-1]
+ f_X[i1, j1 + k] = c[-1]
# @torch.compile
def contact(self, X, i, j, q):
return no, nq, nq_diag
- # @torch.compile
- def task_count(self, A, f_A, B, f_B):
- N = (torch.randint(4, (1,)) + 2).item()
- c = torch.randperm(len(self.colors) - 1)[:N] + 1
+ def REMOVED_task_count(self, A, f_A, B, f_B):
+ # Recolours island maps with N colours; in the answer the colour that was
+ # assigned to strictly more island labels than any other is repainted
+ # with the extra colour c[-1]. The whole pair (A, B) is redrawn until
+ # no error is flagged.
+ # NOTE(review): assumes grow_islands yields integer label maps with 0 as
+ # background -- confirm against its definition.
+ while True:
+ error = False
+
+ N = 3
+ # c[0] is the background value 0, c[1:] are N + 1 distinct colours.
+ c = torch.zeros(N + 2, dtype=torch.int64)
+ c[1:] = torch.randperm(self.nb_colors - 1)[: N + 1] + 1
+
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ # Refill the pool of pre-generated maps when it is missing or empty.
+ if not hasattr(self, "cache_count") or len(self.cache_count) == 0:
+ self.cache_count = list(
+ grow_islands(
+ 1000,
+ self.height,
+ self.width,
+ nb_seeds=self.height * self.width // 8,
+ nb_iterations=self.height * self.width // 5,
+ )
+ )
- for X, f_X in [(A, f_A), (B, f_B)]:
- nb = torch.zeros(N, dtype=torch.int64)
- q = torch.randint(N, (self.height * self.width,))
- k = torch.randperm(self.height * self.width)
- for p in range(self.height * self.width):
- i, j = k[p] % self.height, k[p] // self.height
- no, nq, nq_diag = self.contact(X, i, j, c[q[p]])
- if no == 0 and nq_diag == 0:
- if nq == 0:
- if nb[q[p]] < self.width:
- X[i, j] = c[q[p]]
- nb[q[p]] += 1
- if nq == 1:
- X[i, j] = c[q[p]]
-
- for n in range(N):
- for j in range(nb[n]):
- f_X[n, j] = c[n]
+ X[...] = self.cache_count.pop()
+
+ # k = (X.max() + 1 + (c.size(0) - 1)).item()
+ # V = torch.arange(k) // (c.size(0) - 1)
+ # V = (V + torch.rand(V.size())).sort().indices[: X.max() + 1] % (
+ # c.size(0) - 1
+ # ) + 1
+
+ # V maps each label value of X to a colour index in 1..N; value 0 keeps
+ # index 0, i.e. the background entry c[0].
+ V = torch.randint(N, (X.max() + 1,)) + 1
+ V[0] = 0
+ # NB[v] is the number of label values that were assigned colour v.
+ NB = F.one_hot(c[V]).sum(dim=0)
+ X[...] = c[V[X]]
+ f_X[...] = X
+
+ # Proceed only if X holds at least three distinct values and a single
+ # entry of c[:-1] reaches the maximal count m.
+ # NOTE(review): the indentation was lost in this listing, so which `if`
+ # the `else` below pairs with cannot be told from here -- confirm.
+ if F.one_hot(X.flatten()).max(dim=0).values.sum().item() >= 3:
+ m = NB[c[:-1]].max()
+ if (NB[c[:-1]] == m).long().sum() == 1:
+ for e in range(1, N + 1):
+ if NB[c[e]] == m:
+ # Repaint the cells of the winning colour with c[-1] in the answer.
+ a = (f_X == c[e]).long()
+ f_X[...] = (1 - a) * f_X + a * c[-1]
+ else:
+ error = True
+ break
+
+ if not error:
+ break
+
+ # Sanity check: A shows at least three distinct values.
+ assert F.one_hot(A.flatten()).max(dim=0).values.sum() >= 3
# @torch.compile
def task_trajectory(self, A, f_A, B, f_B):
- c = torch.randperm(len(self.colors) - 1)[:2] + 1
+ c = torch.randperm(self.nb_colors - 1)[:2] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
while True:
di, dj = torch.randint(7, (2,)) - 3
- i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,))
+ i, j = (
+ torch.randint(self.height, (1,)).item(),
+ torch.randint(self.width, (1,)).item(),
+ )
if (
abs(di) + abs(dj) > 0
and i + 2 * di >= 0
# @torch.compile
def task_bounce(self, A, f_A, B, f_B):
- c = torch.randperm(len(self.colors) - 1)[:3] + 1
+ c = torch.randperm(self.nb_colors - 1)[:3] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
# @torch.compile
def free(i, j):
X[...] = 0
for _ in range((self.height * self.width) // 10):
- i, j = torch.randint(self.height, (1,)), torch.randint(
- self.width, (1,)
+ i, j = (
+ torch.randint(self.height, (1,)).item(),
+ torch.randint(self.width, (1,)).item(),
)
X[i, j] = c[0]
f_X[i, j] = c[0]
if abs(di) + abs(dj) == 1:
break
- i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,))
+ i, j = (
+ torch.randint(self.height, (1,)).item(),
+ torch.randint(self.width, (1,)).item(),
+ )
X[i, j] = c[1]
f_X[i, j] = c[1]
f_X[i, j] = c[2]
if l <= 1:
X[i, j] = c[2]
+ f_X[i, j] = c[1]
if l >= self.width:
break
# @torch.compile
def task_scale(self, A, f_A, B, f_B):
- c = torch.randperm(len(self.colors) - 1)[:2] + 1
+ c = torch.randperm(self.nb_colors - 1)[:2] + 1
- i, j = torch.randint(self.height // 2, (1,)), torch.randint(
- self.width // 2, (1,)
+ i, j = (
+ torch.randint(self.height // 2, (1,)).item(),
+ torch.randint(self.width // 2, (1,)).item(),
)
for X, f_X in [(A, f_A), (B, f_B)]:
for _ in range(3):
while True:
- i1, j1 = torch.randint(self.height // 2 + 1, (1,)), torch.randint(
- self.width // 2 + 1, (1,)
+ i1, j1 = (
+ torch.randint(self.height // 2 + 1, (1,)).item(),
+ torch.randint(self.width // 2 + 1, (1,)).item(),
)
- i2, j2 = torch.randint(self.height // 2 + 1, (1,)), torch.randint(
- self.width // 2 + 1, (1,)
+ i2, j2 = (
+ torch.randint(self.height // 2 + 1, (1,)).item(),
+ torch.randint(self.width // 2 + 1, (1,)).item(),
)
if i1 < i2 and j1 < j2 and min(i2 - i1, j2 - j1) <= 3:
break
X[i + i1 : i + i2, j + j1 : j + j2] = c[0]
f_X[2 * i1 : 2 * i2, 2 * j1 : 2 * j2] = c[0]
- X[i, j] = c[1]
- f_X[0:2, 0:2] = c[1]
+ for k in range(2):
+ X[i + k, j] = c[1]
+ X[i, j + k] = c[1]
+ f_X[i + k, j] = c[1]
+ f_X[i, j + k] = c[1]
# @torch.compile
def task_symbols(self, A, f_A, B, f_B):
nb_rec = 4
- c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+ c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
delta = 3
for X, f_X in [(A, f_A), (B, f_B)]:
while True:
if d.min() > delta:
break
- for k in range(1, nb_rec):
- X[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k]
-
ai, aj = i.float().mean(), j.float().mean()
- q = torch.randint(3, (1,)) + 1
-
- X[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0]
- X[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0]
- X[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0]
- X[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0]
+ q = torch.randint(3, (1,)).item() + 1
assert i[q] != ai and j[q] != aj
- X[
+ for Z in [X, f_X]:
+ for k in range(0, nb_rec):
+ Z[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k]
+ # Z[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0]
+ # Z[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0]
+ # Z[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0]
+ # Z[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0]
+
+ # f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
+
+ f_X[i[0] + delta // 2, j[0] + delta // 2] = c[q]
+ # f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
+
+ ii, jj = (
i[0] + delta // 2 + (i[q] - ai).sign().long(),
j[0] + delta // 2 + (j[q] - aj).sign().long(),
- ] = c[nb_rec]
+ )
+
+ X[ii, jj] = c[nb_rec]
+ X[i[0] + delta // 2, jj] = c[nb_rec]
+ X[ii, j[0] + delta // 2] = c[nb_rec]
- f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
+ f_X[ii, jj] = c[nb_rec]
+ f_X[i[0] + delta // 2, jj] = c[nb_rec]
+ f_X[ii, j[0] + delta // 2] = c[nb_rec]
# @torch.compile
- def task_ortho(self, A, f_A, B, f_B):
+ def task_isometry(self, A, f_A, B, f_B):
nb_rec = 3
di, dj = torch.randint(3, (2,)) - 1
o = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])
m = torch.eye(2)
- for _ in range(torch.randint(4, (1,))):
+ for _ in range(torch.randint(4, (1,)).item()):
m = m @ o
if torch.rand(1) < 0.5:
m[0, :] = -m[0, :]
X[...] = 0
f_X[...] = 0
- c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+ c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
for r in range(nb_rec):
while True:
):
break
+    def compute_distance(self, walls, goal_i, goal_j):
+        """Return the 4-connected shortest-path length from every cell to the goal.
+
+        `walls` is a 2D tensor holding 1 on wall cells and 0 on free cells.
+        Only the interior `[1:-1, 1:-1]` is ever updated, so callers are
+        expected to surround the grid with a one-cell border of walls.
+
+        In the returned tensor the goal holds 0, wall cells hold 0, and free
+        cells from which the goal cannot be reached hold `walls.numel()`.
+        """
+        max_length = walls.numel()
+        dist = torch.full_like(walls, max_length)
+
+        dist[goal_i, goal_j] = 0
+        pred_dist = torch.empty_like(dist)
+
+        while True:
+            pred_dist.copy_(dist)
+            # Length of the best route that enters the cell from one of its
+            # four neighbours.
+            through_neighbor = (
+                torch.cat(
+                    (
+                        dist[None, 1:-1, 0:-2],
+                        dist[None, 2:, 1:-1],
+                        dist[None, 1:-1, 2:],
+                        dist[None, 0:-2, 1:-1],
+                    ),
+                    0,
+                ).min(dim=0)[0]
+                + 1
+            )
+            # A cell keeps its current value when it is already shorter.
+            # Adding 1 to the cell's own value as well (as the previous code
+            # did) overwrites the goal's 0 and makes every value drift upward
+            # until it exceeds max_length.
+            dist[1:-1, 1:-1] = torch.min(dist[1:-1, 1:-1], through_neighbor)
+
+            # Walls are pinned to max_length so no route goes through them.
+            dist = walls * max_length + (1 - walls) * dist
+
+            if dist.equal(pred_dist):
+                return dist * (1 - walls)
+
# @torch.compile
- def task_islands(self, A, f_A, B, f_B):
- pass
+    def REMOVED_task_distance(self, A, f_A, B, f_B):
+        """Draw in the answer the cells lying on a shortest path between two marks.
+
+        Each grid gets one to three obstacle rectangles of colour c[0] and two
+        free cells marked with c[2]. The answer additionally paints with c[1]
+        every cell that sits on some shortest obstacle-avoiding route between
+        the two marks. A, f_A, B and f_B are written in place.
+        """
+        c = torch.randperm(self.nb_colors - 1)[:3] + 1
+        # Distance buffers padded with a one-cell wall border.
+        dist0 = torch.empty(self.height + 2, self.width + 2)
+        dist1 = torch.empty(self.height + 2, self.width + 2)
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            nb_rec = torch.randint(3, (1,)).item() + 1
+            # Redraw until the two marks are distinct and reasonably close.
+            while True:
+                r = self.rec_coo(nb_rec, prevent_overlap=True)
+                X[...] = 0
+                f_X[...] = 0
+                for n in range(nb_rec):
+                    i1, j1, i2, j2 = r[n]
+                    X[i1:i2, j1:j2] = c[0]
+                    f_X[i1:i2, j1:j2] = c[0]
+                while True:
+                    i0, j0 = (
+                        torch.randint(self.height, (1,)).item(),
+                        torch.randint(self.width, (1,)).item(),
+                    )
+                    if X[i0, j0] == 0:
+                        break
+                while True:
+                    i1, j1 = (
+                        torch.randint(self.height, (1,)).item(),
+                        torch.randint(self.width, (1,)).item(),
+                    )
+                    if X[i1, j1] == 0:
+                        break
+                dist1[...] = 1
+                dist1[1:-1, 1:-1] = (X != 0).long()
+                dist1[...] = self.compute_distance(dist1, i1 + 1, j1 + 1)
+                if (
+                    dist1[i0 + 1, j0 + 1] >= 1
+                    and dist1[i0 + 1, j0 + 1] < self.height * 4
+                ):
+                    break
+
+            dist0[...] = 1
+            dist0[1:-1, 1:-1] = (X != 0).long()
+            dist0[...] = self.compute_distance(dist0, i0 + 1, j0 + 1)
+
+            # Un-padded views under their own names: rebinding dist0 / dist1
+            # here would shrink the buffers and make the pass over (B, f_B)
+            # fail on a shape mismatch.
+            inner0 = dist0[1:-1, 1:-1]
+            inner1 = dist1[1:-1, 1:-1]
+
+            # The buffers are float tensors and range() needs a Python int.
+            D = int(inner1[i0, j0].item())
+            for d in range(1, D):
+                # Cast to integers: `1 - M` is not defined for bool tensors.
+                M = ((inner0 == d) & (inner1 == D - d)).long()
+                f_X[...] = (1 - M) * f_X + M * c[1]
+
+            X[i0, j0] = c[2]
+            f_X[i0, j0] = c[2]
+            X[i1, j1] = c[2]
+            f_X[i1, j1] = c[2]
# for X, f_X in [(A, f_A), (B, f_B)]:
# n = torch.arange(self.height * self.width).reshape(self.height, self.width)
# i,j=q%self.height,q//self.height
# if
- ######################################################################
+ # @torch.compile
+ def TOO_HARD_task_puzzle(self, A, f_A, B, f_B):
+ # Jigsaw task: the answer holds an S x S block made of four coloured
+ # pieces, the prompt holds the same pieces placed at random positions.
+ S = 4
+ # Top-left corner of the S x S block inside the grid.
+ i0, j0 = (self.height - S) // 2, (self.width - S) // 2
+ c = torch.randperm(self.nb_colors - 1)[:4] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ # Grow the pieces in f_X; the loop exits once n.sum() == S * S.
+ # NOTE(review): indentation was lost in this listing, so the exact
+ # nesting of the exit test below should be confirmed.
+ while True:
+ f_X[...] = 0
+ # h: colour indices not yet used as a seed; n: cell count per colour value.
+ h = list(torch.randperm(c.size(0)))
+ n = torch.zeros(c.max() + 1)
+ for _ in range(2):
+ k = torch.randperm(S * S)
+ for q in k:
+ i, j = q % S + i0, q // S + j0
+ if f_X[i, j] == 0:
+ # Values of the four neighbours, examined in random order.
+ r, s, t, u = (
+ f_X[i - 1, j],
+ f_X[i, j - 1],
+ f_X[i + 1, j],
+ f_X[i, j + 1],
+ )
+ r, s, t, u = torch.tensor([r, s, t, u])[torch.randperm(4)]
+ # Extend the first coloured neighbour whose piece has fewer than 6 cells.
+ if r > 0 and n[r] < 6:
+ n[r] += 1
+ f_X[i, j] = r
+ elif s > 0 and n[s] < 6:
+ n[s] += 1
+ f_X[i, j] = s
+ elif t > 0 and n[t] < 6:
+ n[t] += 1
+ f_X[i, j] = t
+ elif u > 0 and n[u] < 6:
+ n[u] += 1
+ f_X[i, j] = u
+ else:
+ # Otherwise start a new piece if an unused colour remains.
+ if len(h) > 0:
+ d = c[h.pop()]
+ n[d] += 1
+ f_X[i, j] = d
+
+ if n.sum() == S * S:
+ break
- def all_tasks(self):
- return [
- self.task_replace_color,
- self.task_translate,
- self.task_grow,
- self.task_color_grow,
- self.task_frame,
- self.task_detect,
- self.task_count,
- self.task_trajectory,
- self.task_bounce,
- self.task_scale,
- self.task_symbols,
- self.task_ortho,
- # self.task_islands,
- ]
+ k = 0
+ # Copy each piece into X at a random offset (ii, jj), redrawn until the
+ # piece stays inside the grid and covers no cell of X that is already set.
+ for d in range(4):
+ while True:
+ ii, jj = (
+ torch.randint(self.height, (1,)).item(),
+ torch.randint(self.width, (1,)).item(),
+ )
+ # e is set to 1 as soon as the placement is found invalid.
+ e = 0
+ for i in range(S):
+ for j in range(S):
+ if (
+ ii + i >= self.height
+ or jj + j >= self.width
+ or (
+ f_X[i + i0, j + j0] == c[d]
+ and X[ii + i, jj + j] > 0
+ )
+ ):
+ e = 1
+ if e == 0:
+ break
+ for i in range(S):
+ for j in range(S):
+ if f_X[i + i0, j + j0] == c[d]:
+ X[ii + i, jj + j] = c[d]
+
+    def TOO_MESSY_task_islands(self, A, f_A, B, f_B):
+        """Island-selection task.
+
+        The prompt paints every island with c[0] and marks one island cell
+        with c[1]; the answer paints the island holding that cell with c[1]
+        (the marked cell itself reverting to c[0]) and the others with c[0].
+        """
+        c = torch.randperm(self.nb_colors - 1)[:2] + 1
+        for X, f_X in ((A, f_A), (B, f_B)):
+            # Refill the pool of pre-generated island maps when it runs dry.
+            if len(getattr(self, "cache_islands", ())) == 0:
+                pool = grow_islands(
+                    1000,
+                    self.height,
+                    self.width,
+                    nb_seeds=self.height * self.width // 20,
+                    nb_iterations=self.height * self.width // 2,
+                )
+                self.cache_islands = list(pool)
+
+            labels = self.cache_islands.pop()
+
+            # Rejection-sample a cell of the top-left quadrant lying on an island.
+            while True:
+                i = torch.randint(self.height // 2, (1,)).item()
+                j = torch.randint(self.width // 2, (1,)).item()
+                if labels[i, j] > 0:
+                    break
+
+            on_island = labels > 0
+            picked = labels == labels[i, j]
+
+            X[...] = on_island * c[0]
+            f_X[...] = picked * c[1] + (on_island & ~picked) * c[0]
+            f_X[i, j] = X[i, j]
+            X[i, j] = c[1]
+
+ # @torch.compile
+ def TOO_HARD_task_stack(self, A, f_A, B, f_B):
+ # Stacking task: the prompt lists nine coloured "side" instructions in a
+ # 3 x 3 arrangement of 3 x 3 tiles, the answer shows the shape obtained
+ # by applying them in order around a central 2 x 2 block.
+ N = 5
+ c = torch.randperm(self.nb_colors - 1)[:N] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ # Bounds [i1, i2) x [j1, j2) of the shape built in the answer.
+ i1, j1, i2, j2 = (
+ self.height // 2 - 1,
+ self.width // 2 - 1,
+ self.height // 2 + 1,
+ self.width // 2 + 1,
+ )
+ # Nine operations drawn without replacement from four copies of 0..3.
+ op = torch.tensor((0, 1, 2, 3) * 4)
+ op = op[torch.randperm(op.size(0))[:9]]
+ for q in range(op.size(0)):
+ # (u, v): top-left corner of the q-th tile in the prompt.
+ u = 3 * (q // 3)
+ v = 3 * (q % 3)
+ d = c[torch.randint(N, (1,)).item()]
+ # X[u+1,v+1]=d
+ # Prompt: colour the side of the tile that encodes the operation.
+ if op[q] == 0: # right
+ X[u : u + 3, v + 2] = d
+ elif op[q] == 1: # left
+ X[u : u + 3, v] = d
+ elif op[q] == 2: # bottom
+ X[u + 2, v : v + 3] = d
+ elif op[q] == 3: # top
+ X[u, v : v + 3] = d
+
+ # Answer: the first operation fills the central block whatever its
+ # side; the following ones add a line on their side and extend the bounds.
+ if q == 0:
+ f_X[i1:i2, j1:j2] = d
+ elif op[q] == 0: # right
+ f_X[i1:i2, j2] = d
+ j2 += 1
+ elif op[q] == 1: # left
+ j1 -= 1
+ f_X[i1:i2, j1] = d
+ elif op[q] == 2: # bottom
+ f_X[i2, j1:j2] = d
+ i2 += 1
+ elif op[q] == 3: # top
+ i1 -= 1
+ f_X[i1, j1:j2] = d
+
+    def randint(self, *m):
+        """Draw one integer per bound; the k-th value is uniform in [0, m[k])."""
+        bounds = torch.tensor(m)
+        uniform = torch.rand(bounds.size())
+        # Truncating uniform * bound maps [0, 1) onto {0, ..., bound - 1}.
+        return torch.mul(uniform, bounds).long()
+
+    def TOO_HARD_task_matrices(self, A, f_A, B, f_B):
+        """Matrix-product task.
+
+        Two random binary 5 x 5 matrices are drawn side by side in the top
+        rows of the prompt and of the answer; the answer also shows their
+        product, colour-coded, in the block [5:10, 5:10].
+        """
+        N = 6
+        c = torch.randperm(self.nb_colors - 1)[:N] + 1
+
+        for X, f_X in ((A, f_A), (B, f_B)):
+            M1 = torch.randint(2, (5, 5))
+            M2 = torch.randint(2, (5, 5))
+            # Entries of the product lie in 0..5, hence the N = 6 colours.
+            P = M1 @ M2
+
+            left, right = c[M1], c[M2]
+            for Z in (X, f_X):
+                Z[0:5, 0:5] = left
+                Z[0:5, 5:10] = right
+            f_X[5:10, 5:10] = c[P]
+
+ def TOO_HARD_task_compute(self, A, f_A, B, f_B):
+ # Proportion task: each of the N colours gets a hidden integer value;
+ # rows show pairs of coloured bars whose lengths are tied to these
+ # values, and the last row must be completed in the answer.
+ N = 6
+ c = torch.randperm(self.nb_colors - 1)[:N] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ # Hidden value of each colour, in 1..(width - 1) // 2.
+ v = torch.randint((self.width - 1) // 2, (N,)) + 1
+ # Random ordering of the colours; consecutive colours form the shown pairs.
+ chain = torch.randperm(N)
+ eq = []
+ for i in range(chain.size(0) - 1):
+ i1, i2 = chain[i], chain[i + 1]
+ v1, v2 = v[i1], v[i2]
+ k = torch.arange(self.width // 2) + 1
+ # Rows of d are the length pairs (w1, w2) in 1..width // 2 such
+ # that w2 * v1 == w1 * v2; one of them is picked at random.
+ d = ((k[None, :] * v1 - k[:, None] * v2) == 0).nonzero() + 1
+ d = d[torch.randint(d.size(0), (1,)).item()]
+ w1, w2 = d
+ eq.append((c[i1], w1, c[i2], w2))
+
+ # One distinct row, among the first height - 2, per shown pair.
+ ii = torch.randperm(self.height - 2)[: len(eq)]
+
+ for k, x in enumerate(eq):
+ i = ii[k]
+ c1, w1, c2, w2 = x
+ # Random horizontal offset that keeps both bars inside the grid.
+ s = torch.randint(self.width - (w1 + w2) + 1, (1,)).item()
+ X[i, s : s + w1] = c1
+ X[i, s + w1 : s + w1 + w2] = c2
+ f_X[i, s : s + w1] = c1
+ f_X[i, s + w1 : s + w1 + w2] = c2
+
+ # Question row: a random pair of colours, drawn on the last row.
+ i1, i2 = torch.randperm(N)[:2]
+ v1, v2 = v[i1], v[i2]
+ k = torch.arange(self.width // 2) + 1
+ d = ((k[None, :] * v1 - k[:, None] * v2) == 0).nonzero() + 1
+ d = d[torch.randint(d.size(0), (1,)).item()]
+ w1, w2 = d
+ c1, c2 = c[i1], c[i2]
+ s = 0 # torch.randint(self.width - (w1 + w2) + 1, (1,)).item()
+ i = self.height - 1
+ # The prompt shows the first bar and a single cell of the second colour;
+ # the answer shows the second bar with its full length w2.
+ X[i, s : s + w1] = c1
+ X[i, s + w1 : s + w1 + 1] = c2
+ f_X[i, s : s + w1] = c1
+ f_X[i, s + w1 : s + w1 + w2] = c2
+
+ # @torch.compile
+ # [ai1,ai2] [bi1,bi2]
+ def task_contact(self, A, f_A, B, f_B):
+ # Contact task: three rectangles are drawn; in the answer, those that
+ # touch the first rectangle receive a border in the first colour c[0].
+ def rec_dist(a, b):
+ # Rectangles are half-open boxes (i1, j1, i2, j2). The result is 0
+ # exactly when the boxes overlap or share part of an edge; touching
+ # only at a corner gives a positive value.
+ ai1, aj1, ai2, aj2 = a
+ bi1, bj1, bi2, bj2 = b
+ # v / h: vertical / horizontal gap, negative when the spans overlap.
+ v = max(ai1 - bi2, bi1 - ai2)
+ h = max(aj1 - bj2, bj1 - aj2)
+ return min(max(v, 0) + max(h + 1, 0), max(v + 1, 0) + max(h, 0))
+
+ nb_rec = 3
+ c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ # Redraw until at least one other rectangle touches the first one.
+ while True:
+ r = self.rec_coo(nb_rec, prevent_overlap=True)
+ d = [rec_dist(r[0], r[k]) for k in range(nb_rec)]
+ if min(d[1:]) == 0:
+ break
+
+ for n in range(nb_rec):
+ i1, j1, i2, j2 = r[n]
+ X[i1:i2, j1:j2] = c[n]
+ f_X[i1:i2, j1:j2] = c[n]
+ # d[0] is always 0, so the first rectangle is framed in its own colour.
+ if d[n] == 0:
+ f_X[i1, j1:j2] = c[0]
+ f_X[i2 - 1, j1:j2] = c[0]
+ f_X[i1:i2, j1] = c[0]
+ f_X[i1:i2, j2 - 1] = c[0]
+
+ # @torch.compile
+ # [ai1,ai2] [bi1,bi2]
+ def task_corners(self, A, f_A, B, f_B):
+ # Corner task: the prompt shows, for each rectangle, L-shaped marks at two
+ # opposite corners; the answer shows the full rectangles.
+ # polarity selects the diagonal, the same one for A and B:
+ # 0 -> top-left and bottom-right, 1 -> top-right and bottom-left.
+ polarity = torch.randint(2, (1,)).item()
+ nb_rec = 3
+ c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ r = self.rec_coo(nb_rec, prevent_overlap=True)
+
+ for n in range(nb_rec):
+ i1, j1, i2, j2 = r[n]
+ # k = 0, 1 draws the corner cell and its two neighbours along the sides.
+ for k in range(2):
+ if polarity == 0:
+ X[i1 + k, j1] = c[n]
+ X[i2 - 1 - k, j2 - 1] = c[n]
+ X[i1, j1 + k] = c[n]
+ X[i2 - 1, j2 - 1 - k] = c[n]
+ else:
+ X[i1 + k, j2 - 1] = c[n]
+ X[i2 - 1 - k, j1] = c[n]
+ X[i1, j2 - 1 - k] = c[n]
+ X[i2 - 1, j1 + k] = c[n]
+ f_X[i1:i2, j1:j2] = c[n]
+
+    def compdist(self, X, i, j):
+        """Return the 4-connected walking distance from cell (i, j) to every cell.
+
+        Cells where X > 0 are obstacles. Obstacles, and free cells that cannot
+        be reached from (i, j), hold height * width in the result.
+        """
+        unreachable = self.height * self.width
+        # Buffer with a one-cell border that always stays at `unreachable`.
+        padded = X.new_full((self.height + 2, self.width + 2), unreachable)
+        inner = padded[1:-1, 1:-1]  # view: writing to it updates `padded`
+        blocked = (X > 0).long()
+        inner[i, j] = 0
+        previous = inner.clone()
+        while True:
+            previous.copy_(inner)
+            # Relax every cell against its four neighbours, all read from the
+            # state at the start of the sweep.
+            best = inner
+            for neighbor in (
+                padded[:-2, 1:-1],
+                padded[2:, 1:-1],
+                padded[1:-1, :-2],
+                padded[1:-1, 2:],
+            ):
+                best = best.min(neighbor + 1)
+            inner[...] = best
+            # Obstacles are pinned so that no route goes through them.
+            inner[...] = (1 - blocked) * inner + blocked * unreachable
+            if previous.equal(inner):
+                return inner
+
+ # @torch.compile
+ def task_path(self, A, f_A, B, f_B):
+ # Path task: two obstacle rectangles and two end points marked with c[-2];
+ # the answer paints with c[-1] the cells lying on a shortest
+ # obstacle-avoiding path between the end points.
+ nb_rec = 2
+ c = torch.randperm(self.nb_colors - 1)[: nb_rec + 2] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ # Redraw the scene until the end points are acceptable.
+ while True:
+ X[...] = 0
+ f_X[...] = 0
+
+ r = self.rec_coo(nb_rec, prevent_overlap=True)
+ for n in range(nb_rec):
+ i1, j1, i2, j2 = r[n]
+ X[i1:i2, j1:j2] = c[n]
+ f_X[i1:i2, j1:j2] = c[n]
+
+ # End points (i1, j1) and (i2, j2): both on empty cells and more than
+ # 2 apart in Manhattan distance.
+ i1, i2 = torch.randint(self.height, (2,))
+ j1, j2 = torch.randint(self.width, (2,))
+ if (
+ abs(i1 - i2) + abs(j1 - j2) > 2
+ and X[i1, j1] == 0
+ and X[i2, j2] == 0
+ ):
+ # Distance maps from each end point.
+ d2 = self.compdist(X, i2, j2)
+ d = self.compdist(X, i1, j1)
+
+ # NOTE(review): indentation was lost in this listing; the test below
+ # reads d2, so it presumably sits inside the `if` above -- confirm.
+ if d2[i1, j1] < 2 * self.width:
+ break
+
+ # A cell is on a shortest path when its distances to the two end points
+ # add up to the distance between them.
+ m = ((d + d2) == d[i2, j2]).long()
+ f_X[...] = m * c[-1] + (1 - m) * f_X
+
+ X[i1, j1] = c[-2]
+ X[i2, j2] = c[-2]
+ f_X[i1, j1] = c[-2]
+ f_X[i2, j2] = c[-2]
+
+ # @torch.compile
+ def task_fill(self, A, f_A, B, f_B):
+ # Flood-fill task: a seed cell of colour c[-1] is put on an empty cell of
+ # the prompt; the answer paints with c[-1] the cells reachable from it
+ # without crossing a rectangle.
+ nb_rec = 3
+ c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ # With probability 1/2 any scene is accepted; otherwise the scene must
+ # leave some empty cell out of reach of the seed.
+ accept_full = torch.rand(1) < 0.5
+
+ while True:
+ X[...] = 0
+ f_X[...] = 0
+
+ r = self.rec_coo(nb_rec, prevent_overlap=True)
+ for n in range(nb_rec):
+ i1, j1, i2, j2 = r[n]
+ X[i1:i2, j1:j2] = c[n]
+ f_X[i1:i2, j1:j2] = c[n]
+
+ # Rejection-sample the seed on an empty cell.
+ while True:
+ i, j = (
+ torch.randint(self.height, (1,)).item(),
+ torch.randint(self.width, (1,)).item(),
+ )
+ if X[i, j] == 0:
+ break
+
+ # compdist gives height * width to cells that cannot be reached.
+ d = self.compdist(X, i, j)
+ m = (d < self.height * self.width).long()
+ X[i, j] = c[-1]
+ f_X[...] = m * c[-1] + (1 - m) * f_X
+ # The seed cell itself is left blank in the answer.
+ f_X[i, j] = 0
+
+ if accept_full or (d * (X == 0)).max() == self.height * self.width:
+ break
+
+    def TOO_HARD_task_addition(self, A, f_A, B, f_B):
+        """Binary-addition task.
+
+        Rows 0 and 1 of the prompt and of the answer hold two random numbers
+        written in binary with colours c[0] / c[1]; row 2 of the answer holds
+        their sum written with colours c[2] / c[3].
+        """
+        c = torch.randperm(self.nb_colors - 1)[:4] + 1
+        for X, f_X in ((A, f_A), (B, f_B)):
+            bound = 2 ** (self.width - 1) - 1
+            N1 = torch.randint(bound, (1,)).item()
+            N2 = torch.randint(bound, (1,)).item()
+            total = N1 + N2
+            for bit in range(self.width):
+                col = -bit - 1  # least significant bit in the rightmost column
+                for row, number in enumerate((N1, N2)):
+                    digit = (number >> bit) & 1
+                    X[row, col] = c[digit]
+                    f_X[row, col] = c[digit]
+                f_X[2, col] = c[2 + ((total >> bit) & 1)]
+
+ def task_science_implicit(self, A, f_A, B, f_B):
+ # Implicit-box task: four bars are drawn around an invisible box; the
+ # answer additionally shows the box itself in colour c[0].
+ nb_rec = 5
+ c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
+
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ # Rows [i1, i2) of the box: at least one free row above and more than
+ # three rows tall.
+ while True:
+ i1, i2 = torch.randint(self.height, (2,)).sort().values
+ if i1 >= 1 and i2 < self.height and i1 + 3 < i2:
+ break
+
+ # Columns [j1, j2) of the box, with the same constraints.
+ while True:
+ j1, j2 = torch.randint(self.width, (2,)).sort().values
+ if j1 >= 1 and j2 < self.width and j1 + 3 < j2:
+ break
+
+ # The box is written to the answer only.
+ f_X[i1:i2, j1:j2] = c[0]
+
+ # ---------------------
- def trivial_prompts_and_answers(self, prompts, answers):
+ # Bar ending against the left side of the box, rows within [i1, i2].
+ while True:
+ ii1, ii2 = torch.randint(self.height, (2,)).sort().values
+ if ii1 >= i1 and ii2 <= i2 and ii1 + 1 < ii2:
+ break
+ jj = torch.randint(j1, (1,))
+ X[ii1:ii2, jj:j1] = c[1]
+ f_X[ii1:ii2, jj:j1] = c[1]
+
+ # Bar starting from the right side of the box.
+ while True:
+ ii1, ii2 = torch.randint(self.height, (2,)).sort().values
+ if ii1 >= i1 and ii2 <= i2 and ii1 + 1 < ii2:
+ break
+ jj = torch.randint(self.width - j2, (1,)) + j2 + 1
+ X[ii1:ii2, j2:jj] = c[2]
+ f_X[ii1:ii2, j2:jj] = c[2]
+
+ # ---------------------
+
+ # Bar ending against the top side of the box, columns within [j1, j2].
+ while True:
+ jj1, jj2 = torch.randint(self.width, (2,)).sort().values
+ if jj1 >= j1 and jj2 <= j2 and jj1 + 1 < jj2:
+ break
+ ii = torch.randint(i1, (1,))
+ X[ii:i1, jj1:jj2] = c[3]
+ f_X[ii:i1, jj1:jj2] = c[3]
+
+ # Bar starting from the bottom side of the box.
+ while True:
+ jj1, jj2 = torch.randint(self.width, (2,)).sort().values
+ if jj1 >= j1 and jj2 <= j2 and jj1 + 1 < jj2:
+ break
+ ii = torch.randint(self.height - i2, (1,)) + i2 + 1
+ X[i2:ii, jj1:jj2] = c[4]
+ f_X[i2:ii, jj1:jj2] = c[4]
+
+ def task_science_dot(self, A, f_A, B, f_B):
+ # Dot task: a dot of colour c[-1] is placed on the grid; in the answer,
+ # every rectangle crossed by the dot's row (resp. column) gets that row
+ # (resp. column) segment painted with c[-1].
+ nb_rec = 3
+ c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ # Redraw until the dot's row and column make at least two crossings.
+ while True:
+ X[...] = 0
+ f_X[...] = 0
+ r = self.rec_coo(nb_rec, prevent_overlap=True)
+ i, j = (
+ torch.randint(self.height, (1,)).item(),
+ torch.randint(self.width, (1,)).item(),
+ )
+ # q counts the row crossings plus the column crossings.
+ q = 0
+ for n in range(nb_rec):
+ i1, j1, i2, j2 = r[n]
+ X[i1:i2, j1:j2] = c[n]
+ f_X[i1:i2, j1:j2] = c[n]
+ if i >= i1 and i < i2:
+ q += 1
+ f_X[i, j1:j2] = c[-1]
+ if j >= j1 and j < j2:
+ q += 1
+ f_X[i1:i2, j] = c[-1]
+ # NOTE(review): indentation was lost in this listing; whether the dot is
+ # written once per rectangle or once per attempt cannot be told from
+ # here (the final grids are the same either way).
+ X[i, j] = c[-1]
+ f_X[i, j] = c[-1]
+ if q >= 2:
+ break
+
+    def collide(self, s, r, rs):
+        """Tell whether position r is closer than s, on both axes, to a position of rs."""
+        row, col = r
+        return any(
+            abs(row - other_row) < s and abs(col - other_col) < s
+            for other_row, other_col in rs
+        )
+
+    def task_science_tag(self, A, f_A, B, f_B):
+        """Tagging task.
+
+        Four non-colliding hollow 3 x 3 squares are drawn: the first with
+        c[0], the second with c[1], the last two with c[2]. In the answer the
+        centre of every c[2] square is filled with c[-1].
+        """
+        c = torch.randperm(self.nb_colors - 1)[:4] + 1
+        for X, f_X in ((A, f_A), (B, f_B)):
+            # Rejection-sample four top-left corners that do not collide.
+            corners = []
+            while len(corners) < 4:
+                i = torch.randint(self.height - 3, (1,)).item()
+                j = torch.randint(self.width - 3, (1,)).item()
+                if self.collide(s=3, r=(i, j), rs=corners):
+                    continue
+                corners.append((i, j))
+
+            for k, (i, j) in enumerate(corners):
+                q = min(k, 2)
+                # Same square outline in the prompt and in the answer.
+                for Z in (X, f_X):
+                    Z[i, j : j + 3] = c[q]
+                    Z[i + 2, j : j + 3] = c[q]
+                    Z[i : i + 3, j] = c[q]
+                    Z[i : i + 3, j + 2] = c[q]
+                if q == 2:
+                    f_X[i + 1, j + 1] = c[-1]
+
+ # end_tasks
+
+ ######################################################################
+
+ def create_empty_quizzes(self, nb, quad_order=("A", "f_A", "B", "f_B")):
S = self.height * self.width
- Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S]
- f_Bs = answers
- return (Bs == f_Bs).long().min(dim=-1).values > 0
+ quizzes = torch.zeros(nb, 4 * (S + 1), dtype=torch.int64)
+ quizzes[:, 0 * (S + 1)] = self.l2tok[quad_order[0]]
+ quizzes[:, 1 * (S + 1)] = self.l2tok[quad_order[1]]
+ quizzes[:, 2 * (S + 1)] = self.l2tok[quad_order[2]]
+ quizzes[:, 3 * (S + 1)] = self.l2tok[quad_order[3]]
- def generate_prompts_and_answers(
- self, nb, tasks=None, progress_bar=False, device="cpu"
- ):
- if tasks is None:
- tasks = self.all_tasks()
+ return quizzes
+ def generate_w_quizzes_(self, nb, tasks=None, progress_bar=False):
S = self.height * self.width
- prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64)
- answers = torch.zeros(nb, S, dtype=torch.int64)
- bunch = zip(prompts, answers)
+ if tasks is None:
+ tasks = self.all_tasks
+
+ quizzes = torch.empty(nb, 4 * self.height * self.width, dtype=torch.int64)
if progress_bar:
- bunch = tqdm.tqdm(
- bunch,
+ quizzes = tqdm.tqdm(
+ quizzes,
dynamic_ncols=True,
- desc="world generation",
- total=prompts.size(0),
+ desc="world quizzes generation",
+ total=quizzes.size(0),
)
- for prompt, answer in bunch:
- A = prompt[0 * (S + 1) : 0 * (S + 1) + S].view(self.height, self.width)
- f_A = prompt[1 * (S + 1) : 1 * (S + 1) + S].view(self.height, self.width)
- B = prompt[2 * (S + 1) : 2 * (S + 1) + S].view(self.height, self.width)
- f_B = answer.view(self.height, self.width)
- task = tasks[torch.randint(len(tasks), (1,))]
+ for quiz in quizzes:
+ q = quiz.reshape(4, self.height, self.width)
+ q[...] = 0
+ A, f_A, B, f_B = q
+ task = tasks[torch.randint(len(tasks), (1,)).item()]
task(A, f_A, B, f_B)
- return prompts.flatten(1), answers.flatten(1)
+ return quizzes
- def save_quizzes(
- self,
- result_dir,
- filename_prefix,
- prompts,
- answers,
- predicted_prompts=None,
- predicted_answers=None,
- nrow=4,
- ):
- self.save_image(
- result_dir,
- filename_prefix + ".png",
- prompts,
- answers,
- predicted_prompts,
- predicted_answers,
- nrow,
- )
+    def save_some_examples(self, result_dir, prefix=""):
+        """Save, for each task of all_tasks, an image of 256 generated quizzes.
+
+        The image of a task is written to result_dir under the name
+        prefix + <task function name> + ".png"; the task name is printed
+        before its quizzes are generated.
+        """
+        nb, nrow = 256, 8
+        for task in self.all_tasks:
+            name = task.__name__
+            print(name)
+            quizzes = self.generate_w_quizzes_(nb, tasks=[task])
+            filename = prefix + name + ".png"
+            self.save_quizzes_as_image(result_dir, filename, quizzes, nrow=nrow)
######################################################################
if __name__ == "__main__":
import time
+ # grids = Grids(max_nb_cached_chunks=5, chunk_size=100, nb_threads=4)
+
grids = Grids()
- # nb = 1000
- # grids = problem.MultiThreadProblem(
- # grids, max_nb_cached_chunks=50, chunk_size=100, nb_threads=1
- # )
- # time.sleep(10)
- # start_time = time.perf_counter()
- # prompts, answers = grids.generate_prompts_and_answers(nb)
- # delay = time.perf_counter() - start_time
- # print(f"{prompts.size(0)/delay:02f} seq/s")
- # exit(0)
-
- if True:
- nb = 72
-
- for t in grids.all_tasks():
- # for t in [grids.task_ortho]:
- print(t.__name__)
- prompts, answers = grids.generate_prompts_and_answers(nb, tasks=[t])
- grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4)
+ nb, nrow = 64, 4
+ # nb, nrow = 8, 2
+
+ # for t in grids.all_tasks:
+
+ for t in [
+ grids.task_replace_color,
+ grids.task_translate,
+ grids.task_grow,
+ grids.task_frame,
+ ]:
+ print(t.__name__)
+ w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
+
+ # w_quizzes[:5] = torch.randint(grids.vocabulary_size(), w_quizzes[:5].size())
+
+ grids.save_quizzes_as_image(
+ "/tmp",
+ t.__name__ + ".png",
+ w_quizzes,
+ delta=True,
+ # grids=False
+ # comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))],
+ )
+
+ exit(0)
- exit(0)
+ q = grids.text2quiz(
+ """
+
+# the original
+
+vvvvaaaaa. rrrraaaaa. .......... ..........
+vvvvaaaaa. rrrraaaaa. ...aaa.... ...aaa....
+vvvvaaaaa. rrrraaaaa. ...aaa.... ...aaa....
+vvvvaaaaa. rrrraaaaa. ...aaa.... ...aaa....
+....aaaaa. ....aaaaa. .vvvvv.... .rrrrr....
+.......... .......... .vvvvvvvvv .rrrrroooo
+.......... .......... .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+
+vvvvaaaaa. rrrraaaaa. .......... ..........
+vvvvaaaaa. rrrraaaaa. .......... ..........
+vvvvaaaaa. rrrraaaaa. .......aaa .......aaa
+vvvvaaaaa. rrrraaaaa. .......aaa .......aaa
+....aaaaa. ....aaaaa. .vvvvv.aaa .rrrrr.aaa
+.......... .......... .vvvvvvvvv .rrrrroooo
+.......... .......... .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+
+#
+# so what
+#
+
+vvvv...... rrrr...... .......... ..........
+vvvv...... rrrr...... .......... ..........
+vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa
+vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa
+.....aaaaa .....aaaaa .vvvvv.aaa .rrrrr.aaa
+.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo
+.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+
+vvvv...... rrrr...... .......... ..........
+vvvv...... rrrr...... .......... ..........
+vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa
+vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa
+.....aaaaa .....aaaaa .vvvvv.aaa .rrrrr.aaa
+.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo
+.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+"""
+ )
+
+ grids.save_quizzes_as_image("/tmp", "test.png", q, nrow=1, grids=False)
- nb = 500
+ exit(0)
- for t in grids.all_tasks():
+ nb = 1000
+
+ for t in [
+ # grids.task_bounce,
+ # grids.task_contact,
+ # grids.task_corners,
+ # grids.task_detect,
+ # grids.task_fill,
+ # grids.task_frame,
+ # grids.task_grow,
+ # grids.task_half_fill,
+ # grids.task_isometry,
+ # grids.task_path,
+ # grids.task_replace_color,
+ # grids.task_scale,
+ grids.task_symbols,
+ # grids.task_trajectory,
+ # grids.task_translate,
+ ]:
+ # for t in [grids.task_path]:
start_time = time.perf_counter()
- prompts, answers = grids.generate_prompts_and_answers(nb, tasks=[t])
+ w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
delay = time.perf_counter() - start_time
- print(f"{t.__name__} {prompts.size(0)/delay:02f} seq/s")
+ print(f"{t.__name__} {w_quizzes.size(0)/delay:02f} seq/s")
+ grids.save_quizzes_as_image("/tmp", t.__name__ + ".png", w_quizzes[:128])
exit(0)
predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
- grids.save_quizzes(
+ grids.save_quizzes_as_image(
"/tmp",
- "test",
+ "test.png",
prompts[:nb],
answers[:nb],
# You can add a bool to put a frame around the predicted parts
# Written by Francois Fleuret <francois@fleuret.org>
-import math, sys, argparse, time, tqdm, os, datetime, warnings
+import math, sys, argparse, time, tqdm, os, datetime, warnings, copy
import torch, torchvision
from torch import nn
from torch.nn import functional as F
-import ffutils
+import ffutils, grids, attae
-import mygpt
-import sky, grids, quiz_machine
-from problem import MultiThreadProblem
+import threading, subprocess
-# world quizzes vs. culture quizzes
+# import torch.multiprocessing as mp
-######################################################################
+torch.set_float32_matmul_precision("high")
-if torch.cuda.is_available():
- device = torch.device("cuda")
- torch.backends.cuda.matmul.allow_tf32 = True
-else:
- device = torch.device("cpu")
+# torch.set_default_dtype(torch.bfloat16)
######################################################################
parser = argparse.ArgumentParser(
- description="An implementation of GPT with cache.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--seed", type=int, default=0)
-parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
+parser.add_argument("--resume", action="store_true", default=False)
-########################################
+# ----------------------------------
parser.add_argument("--nb_epochs", type=int, default=10000)
-parser.add_argument("--batch_size", type=int, default=None)
+parser.add_argument("--batch_size", type=int, default=25)
+
+parser.add_argument("--train_batch_size", type=int, default=None)
-parser.add_argument("--physical_batch_size", type=int, default=None)
+parser.add_argument("--eval_batch_size", type=int, default=25)
-parser.add_argument("--nb_train_samples", type=int, default=None)
+parser.add_argument("--nb_train_samples", type=int, default=50000)
-parser.add_argument("--nb_test_samples", type=int, default=None)
+parser.add_argument("--nb_test_samples", type=int, default=2500)
+
+parser.add_argument("--nb_c_quizzes", type=int, default=5000)
+
+parser.add_argument("--c_quiz_multiplier", type=int, default=1)
parser.add_argument("--learning_rate", type=float, default=5e-4)
-########################################
+parser.add_argument("--nb_have_to_be_correct", type=int, default=3)
+
+parser.add_argument("--nb_have_to_be_wrong", type=int, default=1)
-parser.add_argument("--model", type=str, default=None)
+parser.add_argument("--nb_mistakes_to_be_wrong", type=int, default=5)
+
+# ----------------------------------
+
+parser.add_argument("--model_type", type=str, default="standard")
+
+parser.add_argument("--model", type=str, default="37M")
parser.add_argument("--dim_model", type=int, default=None)
parser.add_argument("--nb_blocks", type=int, default=None)
-parser.add_argument("--dropout", type=float, default=0.1)
+parser.add_argument("--dropout", type=float, default=0.5)
-########################################
+# ----------------------------------
-parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
+parser.add_argument("--nb_threads", type=int, default=1)
-parser.add_argument("--problem", type=str, default="grids")
+parser.add_argument("--gpus", type=str, default="all")
-parser.add_argument("--multi_thread_problem", action="store_true", default=False)
+# ----------------------------------
-parser.add_argument("--nb_gpts", type=int, default=5)
+parser.add_argument("--nb_models", type=int, default=5)
-parser.add_argument("--min_to_validate", type=int, default=None)
+parser.add_argument("--diffusion_nb_iterations", type=int, default=25)
-parser.add_argument("--max_to_validate", type=int, default=None)
+parser.add_argument("--diffusion_proba_corruption", type=float, default=0.05)
-parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975)
+parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95)
-parser.add_argument("--generation_temperature", type=float, default=2.0)
+parser.add_argument("--proba_prompt_noise", type=float, default=0.05)
-parser.add_argument("--deterministic_validation", action="store_true", default=False)
+parser.add_argument("--proba_hint", type=float, default=0.25)
-parser.add_argument("--bidirectional_validation", action="store_true", default=False)
-
-parser.add_argument("--dirty_debug", action="store_true", default=False)
+parser.add_argument("--quizzes", type=str, default=None)
######################################################################
-parser.add_argument("--sky_height", type=int, default=6)
-
-parser.add_argument("--sky_width", type=int, default=8)
-
-parser.add_argument("--sky_nb_birds", type=int, default=3)
-
-parser.add_argument("--sky_nb_iterations", type=int, default=2)
+grids_tasks = ", ".join(
+ [x.__name__.removeprefix("task_") for x in grids.Grids().all_tasks]
+)
-parser.add_argument("--sky_speed", type=int, default=3)
+parser.add_argument(
+ "--grids_world_tasks",
+ type=str,
+ default="replace_color,translate,grow,frame",
+ help="A comma-separated subset of: " + grids_tasks + ".",
+)
######################################################################
args = parser.parse_args()
-if args.min_to_validate is None:
- args.min_to_validate = args.nb_gpts - 1
-
-if args.max_to_validate is None:
- args.max_to_validate = args.nb_gpts - 1
-
if args.result_dir is None:
args.result_dir = f"results_culture"
######################################################################
-default_args = {
- "model": "37M",
- "batch_size": 100,
- "nb_train_samples": 100000,
- "nb_test_samples": 10000,
-}
-
-for k, v in default_args.items():
- if getattr(args, k) is None:
- setattr(args, k, v)
-
-######################################################################
-
default_model_args = {
"17K": {
"dim_model": 32,
######################################################################
-try:
- os.mkdir(args.result_dir)
-except FileExistsError:
- print(f"result directory {args.result_dir} already exists")
- exit(1)
+if args.resume:
+ if not os.path.isdir(args.result_dir):
+ print(f"Trying to resume from a non-existing result dir {args.result_dir}.")
+ exit(1)
+else:
+ try:
+ os.mkdir(args.result_dir)
+ except FileExistsError:
+ print(f"result directory {args.result_dir} already exists")
+ exit(1)
log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
def log_string(s):
+    """Print the given string prefixed with a time stamp, and log it
+    into log_file if it is not None"""
+
t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
if log_file is not None:
sys.stdout.flush()
+######################################################################
+# Create a time-stamped archive of the source code
+
+with open("this_run.sh", "w") as f:
+ f.write(f"{' '.join(sys.argv)}\n")
+
+now = time.strftime("%Y%m%d-%H%M%S", time.localtime())
+
+os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py *.sh")
+
+######################################################################
+
log_string(f"argv {' '.join(sys.argv)}")
for n in vars(args):
######################################################################
-if args.dirty_debug:
- args.nb_train_samples = 2500
- args.nb_test_samples = 100
+if args.gpus == "all":
+ gpus_idx = range(torch.cuda.device_count())
+else:
+ gpus_idx = [int(k) for k in args.gpus.split(",")]
+
+gpus = [torch.device(f"cuda:{n}") for n in gpus_idx]
+
+if torch.cuda.is_available():
+ main_device = gpus[0]
+else:
+ assert len(gpus) == 0
+ main_device = torch.device("cpu")
-if args.physical_batch_size is None:
- args.physical_batch_size = args.batch_size
+if args.train_batch_size is None:
+ args.train_batch_size = args.batch_size
else:
- assert args.batch_size % args.physical_batch_size == 0
+ assert args.batch_size % args.train_batch_size == 0
assert args.nb_train_samples % args.batch_size == 0
assert args.nb_test_samples % args.batch_size == 0
-if args.problem == "sky":
- problem = sky.Sky(
- height=args.sky_height,
- width=args.sky_width,
- nb_birds=args.sky_nb_birds,
- nb_iterations=args.sky_nb_iterations,
- speed=args.sky_speed,
- )
- back_accuracy = False
-elif args.problem == "grids":
- problem = grids.Grids(device=device)
- back_accuracy = True
-else:
- raise ValueError
-
-if args.multi_thread_problem:
- problem = MultiThreadProblem(problem, args.nb_train_samples, chunk_size=1000)
-
-quiz_machine = quiz_machine.QuizMachine(
- problem=problem,
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- back_accuracy=back_accuracy,
- batch_size=args.physical_batch_size,
- result_dir=args.result_dir,
- logger=log_string,
- device=device,
-)
+######################################################################
+
+
+def optimizer_to(optim, device):
+    """Transfer the tensors held in the state of optim to device.
+
+    Each entry of optim.state is handled whether it is a tensor itself
+    or a dictionary of tensors; anything else is left untouched. Both
+    the data and, when present, the gradient of each tensor are moved.
+    """
+
+    def move(tensor):
+        tensor.data = tensor.data.to(device)
+        if tensor._grad is not None:
+            tensor._grad.data = tensor._grad.data.to(device)
+
+    for entry in optim.state.values():
+        if isinstance(entry, torch.Tensor):
+            # Not sure there are any global tensors in the state dict
+            move(entry)
+        elif isinstance(entry, dict):
+            for value in entry.values():
+                if isinstance(value, torch.Tensor):
+                    move(value)
+
######################################################################
-log_string(f"device {device}")
-vocabulary_size = quiz_machine.vocabulary_size()
+def generate_quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1):
+    """Return nb_samples shuffled quizzes mixing world and culture quizzes.
+
+    When c_quizzes is None the set only contains quizzes obtained from
+    problem.generate_w_quizzes. Otherwise the c_quizzes are replicated
+    up to c_quiz_multiplier times, capped at half of nb_samples, and the
+    set is completed with world quizzes. The composition of the set is
+    logged.
+    """
+    max_nb_c_quizzes = nb_samples // 2
+
+    if c_quizzes is not None:
+        if c_quiz_multiplier > 1:
+            # As many full copies as fit in half the set, then a random
+            # subset to fill what remains of that half
+            nb_copies = min(c_quiz_multiplier, max_nb_c_quizzes // c_quizzes.size(0))
+            repeated = c_quizzes.repeat(nb_copies, 1)
+            if nb_copies < c_quiz_multiplier:
+                nb_missing = max_nb_c_quizzes - repeated.size(0)
+                extra = c_quizzes[torch.randperm(c_quizzes.size(0))[:nb_missing]]
+                c_quizzes = torch.cat([repeated, extra], dim=0)
+            else:
+                c_quizzes = repeated
+
+        if c_quizzes.size(0) > max_nb_c_quizzes:
+            selected = torch.randperm(c_quizzes.size(0))[:max_nb_c_quizzes]
+            c_quizzes = c_quizzes[selected]
+
+        w_quizzes = problem.generate_w_quizzes(nb_samples - c_quizzes.size(0))
+
+        quizzes = torch.cat([w_quizzes, c_quizzes], dim=0)
+        nb_w_quizzes, nb_c_quizzes = w_quizzes.size(0), c_quizzes.size(0)
+    else:
+        quizzes = problem.generate_w_quizzes(nb_samples)
+        nb_w_quizzes, nb_c_quizzes = quizzes.size(0), 0
+
+    shuffle = torch.randperm(quizzes.size(0), device=quizzes.device)
+    quizzes = quizzes[shuffle].contiguous()
+
+    log_string(f"quiz_set nb_w_quizzes {nb_w_quizzes} nb_c_quizzes {nb_c_quizzes}")
+
+    return quizzes
-log_string(f"vocabulary_size {vocabulary_size}")
######################################################################
-# Compute the entropy of the training tokens
-token_count = 0
-for input in quiz_machine.batches(split="train", desc="train-entropy"):
- token_count += F.one_hot(input, num_classes=quiz_machine.vocabulary_size()).sum(
- (0, 1)
- )
-token_probas = token_count / token_count.sum()
-entropy = -torch.xlogy(token_probas, token_probas).sum()
-train_set_perplexity = math.exp(entropy)
+def add_hints_imt(imt_set):
+    """Reveal some of the masked components as hints.
+
+    Every component whose mask is one is switched to zero with
+    probability args.proba_hint, and for each such component the value
+    of the target is copied into the input. Returns a new tensor with
+    the input, mask and target stacked along dimension 1.
+    """
+    input, masks, targets = imt_set.unbind(dim=1)
+    drawn = torch.rand(input.size(), device=input.device) < args.proba_hint
+    # Only components that are currently masked can become hints
+    mask_hints = drawn.long() * masks
+    keep = 1 - mask_hints
+    hinted_masks = keep * masks
+    hinted_input = keep * input + mask_hints * targets
+    return torch.stack([hinted_input, hinted_masks, targets], dim=1)
+
+
+def add_noise_imt(imt_set):
+    """Corrupt the unmasked part of the input with random values.
+
+    Each component of the input whose mask is zero is replaced, with
+    probability args.proba_prompt_noise, by the corresponding value of
+    a sample obtained from problem.pure_noise. Masks and targets are
+    returned unchanged.
+    """
+    input, masks, targets = imt_set.unbind(dim=1)
+    noise = problem.pure_noise(input.size(0), input.device)
+    drawn = torch.rand(input.size(), device=input.device) < args.proba_prompt_noise
+    # Masked components are never modified
+    change = (1 - masks) * drawn.long()
+    noisy_input = (1 - change) * input + change * noise
+    return torch.stack([noisy_input, masks, targets], dim=1)
+
######################################################################
-# A bit of paranoia never hurts
-
-if args.max_percents_of_test_in_train >= 0:
-
- def subsets_as_tuples(batches, cs):
- s = set()
- for batch in batches:
- for x in batch:
- s.add(tuple([v.item() for v in x]))
- if len(s) == cs:
- yield s
- s = set()
- yield s
-
- nb_test, nb_in_train = 0, 0
- for test_subset in subsets_as_tuples(
- quiz_machine.batches(split="test", desc="test-check"), 25000
- ):
- in_train = set()
- for train_subset in subsets_as_tuples(
- quiz_machine.batches(split="train", desc="train-check"), 25000
- ):
- in_train.update(test_subset.intersection(train_subset))
- nb_in_train += len(in_train)
- nb_test += len(test_subset)
+# Prediction
- log_string(
- f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
- )
- assert (
- nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
- ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"
+def samples_for_prediction_imt(input):
+    """Build prediction samples with one of the four grids masked.
+
+    Each row of input is viewed as four parts of equal length. One part
+    per row, picked uniformly at random, gets a mask of ones and is
+    zeroed in the input; the original row is the target. Returns the
+    input, mask and target stacked along dimension 1.
+    """
+    nb_quizzes = input.size(0)
+    picked = torch.randint(4, (nb_quizzes,), device=input.device)
+    selector = F.one_hot(picked, num_classes=4)
+
+    masks = input.new_zeros(input.size())
+    # Broadcast the per-grid selector over all the cells of each grid
+    masks.view(nb_quizzes, 4, -1)[...] = selector[:, :, None]
+
+    targets = input
+    visible = (1 - masks) * targets
+
+    return torch.stack([visible, masks, targets], dim=1)
-##############################
+def ae_predict(model, imt_set, local_device=main_device):
+    """Fill the masked components of every sample of imt_set.
+
+    imt_set holds, along dimension 1, the input at index 0 and the mask
+    at index 1. The model is put in eval mode and moved to
+    local_device, then applied to mini-batches of
+    args.eval_batch_size samples. Components whose mask is one are
+    replaced by values sampled from the distribution defined by the
+    model's logits, the others keep their input value. Returns the
+    concatenation of the results of all the mini-batches.
+    """
+    model.eval().to(local_device)
-def one_epoch(model, quiz_machine):
-    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+    record = []
-    model.train()
+    src = tqdm.tqdm(
+        imt_set.split(args.eval_batch_size),
+        dynamic_ncols=True,
+        desc="predict",
+        total=imt_set.size(0) // args.eval_batch_size,
+        delay=10,
+    )
-    nb_train_samples, acc_train_loss = 0, 0.0
+    for imt in src:
+        # some paranoia
+        # (work on a copy and zero the input wherever the mask is one,
+        # so that the model cannot see the values it has to predict)
+        imt = imt.clone()
+        imt[:, 0] = imt[:, 0] * (1 - imt[:, 1])
-    for input in quiz_machine.batches(split="train"):
-        input = input.to(device)
+        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+            # The value and its mask bit are combined into a single
+            # index, value * 2 + mask
+            logits = model(imt[:, 0] * 2 + imt[:, 1])
+        dist = torch.distributions.categorical.Categorical(logits=logits)
+        result = (1 - imt[:, 1]) * imt[:, 0] + imt[:, 1] * dist.sample()
+        record.append(result)
-        if nb_train_samples % args.batch_size == 0:
-            optimizer.zero_grad()
+    return torch.cat(record)
- output = model(mygpt.BracketedSequence(input)).x
- loss = F.cross_entropy(output.transpose(1, 2), input)
- acc_train_loss += loss.item() * input.size(0)
- nb_train_samples += input.size(0)
+def predict_the_four_grids(
+    model, input, with_noise=False, with_hints=False, local_device=main_device
+):
+    """Predict each of the four grids of every quiz from the other three.
+
+    Every row of input is duplicated four times, and the k-th copy has
+    its k-th quarter masked. Hints and noise are optionally added with
+    add_hints_imt and add_noise_imt before calling ae_predict. The
+    four masked quarters predicted for a quiz are then assembled into a
+    single row, so the result has the same size as input.
+    """
+    # Four consecutive copies of each quiz
+    input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1))
+    nb = input.size(0)
+    masks = input.new_zeros(input.size())
+    # Copy number i masks the grid number i % 4
+    u = F.one_hot(torch.arange(nb, device=masks.device) % 4, num_classes=4)
+    masks.view(nb, 4, -1)[...] = u[:, :, None]
+    targets = input
+    input = (1 - masks) * targets
+    imt_set = torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
-        loss.backward()
+    if with_hints:
+        imt_set = add_hints_imt(imt_set)
-        if nb_train_samples % args.batch_size == 0:
-            optimizer.step()
+    if with_noise:
+        imt_set = add_noise_imt(imt_set)
-    train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+    result = ae_predict(model, imt_set, local_device=local_device)
+    # Keep in each copy only the grid it had to predict, and sum the
+    # four copies of a quiz to rebuild a complete row
+    result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1)
-    log_string(f"train_perplexity {n_epoch} {train_perplexity}")
+    return result
######################################################################
-def run_tests(model, quiz_machine, deterministic_synthesis):
- with torch.autograd.no_grad():
- model.eval()
+def samples_for_generation_imt(input):
+    """Build denoising samples used to train the generation.
+
+    For each row a number of corruption steps t is drawn between 1 and
+    args.diffusion_nb_iterations, with probabilities decreasing
+    geometrically from 1 to 0.1 (before normalization). Each component
+    is then replaced by noise from problem.pure_noise with probability
+    1 - (1 - args.diffusion_proba_corruption) ** t. The mask is one
+    everywhere and the target is the original row. Returns the input,
+    mask and target concatenated along dimension 1.
+    """
+    nb = input.size(0)
+    probs_iterations = 0.1 ** torch.linspace(
+        0, 1, args.diffusion_nb_iterations, device=input.device
+    )
+    probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
+    probs_iterations = probs_iterations.expand(nb, -1)
+    dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
+    # Sampled indices start at zero, the number of steps starts at one
+    t = dist.sample() + 1
+    r = torch.rand(input.size(), device=input.device)
+    # Probability that a component has been corrupted at least once
+    # in t independent steps
+    proba_erased = 1 - (1 - args.diffusion_proba_corruption) ** t
+    mask_erased = (r <= proba_erased[:, None]).long()
-        nb_test_samples, acc_test_loss = 0, 0.0
-        nb_samples_accumulated = 0
+    noise = problem.pure_noise(nb, input.device)
+    targets = input
+    input = (1 - mask_erased) * input + mask_erased * noise
+    masks = input.new_full(input.size(), 1)
-        for input in quiz_machine.batches(split="test"):
-            input = input.to(device)
+    return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
- bs = model(mygpt.BracketedSequence(input))
- output = bs.x
- loss = F.cross_entropy(output.transpose(1, 2), input)
+def prioritized_rand(low):
+    """Return random values in [0, 1) that are smaller where low is true.
+
+    The result has the size of low. In every row, the components for
+    which low is true receive the smallest values of a uniform random
+    draw and the other components receive the largest ones, the
+    assignment being random inside each of the two groups.
+    """
+    size, device = low.size(), low.device
+    values, _ = torch.rand(size, device=device).sort(dim=1, descending=True)
+    # Random ranking in which all the components with low false come
+    # first, hence get the largest values
+    ranking = (torch.rand(size, device=device) + low.long()).argsort(dim=1)
+    result = torch.empty_like(values)
+    result.scatter_(dim=1, index=ranking, src=values)
+    return result
- acc_test_loss += loss.item() * input.size(0)
- nb_test_samples += input.size(0)
+def ae_generate(model, nb, local_device=main_device):
+    """Generate nb samples by iteratively refining pure noise.
+
+    Starting from problem.pure_noise, at most
+    args.diffusion_nb_iterations refinement passes are done. In each
+    pass the model proposes a new value for every component, and a
+    random subset of the components takes the proposed value. A sample
+    on which a pass changed nothing is not processed anymore. Returns
+    the resulting tensor, located on local_device.
+    """
+    model.eval().to(local_device)
-        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
+    # We loop through the iterations first and through the
+    # mini-batches second so that we keep only the samples that have
+    # not stabilized
-        log_string(f"test_perplexity {n_epoch} {test_perplexity}")
+    all_input = problem.pure_noise(nb, local_device)
+    all_masks = all_input.new_full(all_input.size(), 1)
+    all_changed = torch.full((all_input.size(0),), True, device=all_input.device)
-        model.main_test_accuracy = quiz_machine.produce_results(
-            n_epoch=n_epoch,
-            model=model,
-            result_dir=args.result_dir,
-            deterministic_synthesis=deterministic_synthesis,
+    for it in range(args.diffusion_nb_iterations):
+        # log_string(f"nb_changed {all_changed.long().sum().item()}")
+
+        if not all_changed.any():
+            break
+
+        # Work on copies restricted to the samples that changed during
+        # the previous pass
+        sub_input = all_input[all_changed].clone()
+        sub_masks = all_masks[all_changed].clone()
+        sub_changed = all_changed[all_changed].clone()
+
+        src = zip(
+            sub_input.split(args.eval_batch_size),
+            sub_masks.split(args.eval_batch_size),
+            sub_changed.split(args.eval_batch_size),
+        )
+
+        for input, masks, changed in src:
+            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+                logits = model(input * 2 + masks)
+            dist = torch.distributions.categorical.Categorical(logits=logits)
+            output = dist.sample()
+            # The components where the proposal differs from the
+            # current value get the smallest random values, so they
+            # are the first to pass the threshold below
+            r = prioritized_rand(input != output)
+            mask_changes = (r <= args.diffusion_proba_corruption).long() * masks
+            update = (1 - mask_changes) * input + mask_changes * output
+            # The in-place assignments write into sub_changed and
+            # sub_input through the chunks produced by split
+            changed[...] = changed & (update != input).max(dim=1).values
+            input[...] = update
+
+        # Write the processed subset back, using the selection that
+        # was in force at the beginning of the pass
+        a = all_changed.clone()
+        all_input[a] = sub_input
+        all_masks[a] = sub_masks
+        all_changed[a] = sub_changed
+
+    return all_input
+
+
+######################################################################
+
+
+def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device):
+    """Run one pass of training or of loss evaluation for model.
+
+    A quiz set of args.nb_train_samples (train) or
+    args.nb_test_samples (test) quizzes is generated with
+    generate_quiz_set. The first half is turned into prediction samples
+    and the second half into denoising samples for the generation. The
+    average loss is logged with log_string; nothing is returned.
+
+    When train is true, model.optimizer is used and the gradients of
+    mini-batches of args.train_batch_size samples are accumulated until
+    args.batch_size samples have been processed.
+    """
+    quizzes = generate_quiz_set(
+        args.nb_train_samples if train else args.nb_test_samples,
+        c_quizzes,
+        args.c_quiz_multiplier,
+    )
+
+    q_p, q_g = quizzes.to(local_device).chunk(2)
+
+    # Half of the samples train the prediction, and we inject noise in
+    # all, and hints in half
+    b_p = samples_for_prediction_imt(q_p)
+    b_p = add_noise_imt(b_p)
+    half = torch.rand(b_p.size(0)) < 0.5
+    b_p[half] = add_hints_imt(b_p[half])
+
+    # The other half are denoising examples for the generation
+    b_g = samples_for_generation_imt(q_g)
+
+    imt_set = torch.cat([b_p, b_g])
+    imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)]
+
+    if train:
+        label = "train"
+        model.train().to(local_device)
+        optimizer_to(model.optimizer, local_device)
+        batch_size = args.train_batch_size
+    else:
+        label = "test"
+        model.eval().to(local_device)
+        batch_size = args.eval_batch_size
+
+    nb_samples, acc_loss = 0, 0.0
+
+    # NOTE(review): gradient tracking is not disabled in this function
+    # when train is false -- confirm whether it is done by the caller
+    for imt in tqdm.tqdm(
+        imt_set.split(batch_size),
+        dynamic_ncols=True,
+        desc=label,
+        total=quizzes.size(0) // batch_size,
+        delay=10,
+    ):
+        input, masks, targets = imt.unbind(dim=1)
+        # Start a new accumulation at every multiple of args.batch_size
+        if train and nb_samples % args.batch_size == 0:
+            model.optimizer.zero_grad()
+
+        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+            logits = model(input * 2 + masks)
+
+        loss_per_token = F.cross_entropy(
+            logits.transpose(1, 2), targets, reduction="none"
        )
+        # Unmasked components contribute zero to the loss, but they are
+        # still counted in the mean
+        loss = (loss_per_token * masks).mean()
+        acc_loss += loss.item() * imt.size(0)
+        nb_samples += imt.size(0)
+
+        if train:
+            loss.backward()
+
+            # Step only once a full args.batch_size has been accumulated
+            if nb_samples % args.batch_size == 0:
+                model.optimizer.step()
+
+    log_string(f"{label}_loss {n_epoch} model {model.id} {acc_loss/nb_samples}")
######################################################################
-def valid_c_quizzes(recorded, criteria):
- result = [q[criteria(c)] for q, c in recorded]
- return torch.cat(result, dim=0) if len(result) > 0 else torch.tensor([])
+def save_inference_images(model, n_epoch, c_quizzes, c_quiz_multiplier, local_device):
+    """Save images illustrating the predictions and generations of model.
+
+    Two files are written in args.result_dir with
+    problem.save_quizzes_as_image: culture_prediction_*.png shows up to
+    128 quizzes in which one grid was predicted by the model, and
+    culture_generation_*.png shows up to 128 quizzes generated ex
+    nihilo with ae_generate. The set of 150 quizzes used for the
+    prediction is built from c_quizzes and c_quiz_multiplier.
+    """
+    # Save some images of the prediction results
+
+    # Use the multiplier received as argument instead of silently
+    # reading args.c_quiz_multiplier
+    quizzes = generate_quiz_set(150, c_quizzes, c_quiz_multiplier)
+    imt_set = samples_for_prediction_imt(quizzes.to(local_device))
+    result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
+    masks = imt_set[:, 1].to("cpu")
+
+    # 1 for the quizzes that are reproduced exactly, 0 for the others
+    correct = (quizzes == result).min(dim=1).values.long()
+    # One value per grid: +1 / -1 on the masked grid depending on the
+    # correctness of the quiz, 0 on the three others. The mask of a grid
+    # is read from one of its components (index 1)
+    correct_parts = (2 * correct - 1)[:, None] * masks.reshape(masks.size(0), 4, -1)[
+        :, :, 1
+    ]
+    predicted_parts = correct_parts.abs()
+
+    problem.save_quizzes_as_image(
+        args.result_dir,
+        f"culture_prediction_{n_epoch}_{model.id}.png",
+        quizzes=result[:128],
+        predicted_parts=predicted_parts[:128],
+        correct_parts=correct_parts[:128],
+    )
+
+    # Save some images of the ex nihilo generation of the four grids
+
+    result = ae_generate(model, 150, local_device=local_device).to("cpu")
+    problem.save_quizzes_as_image(
+        args.result_dir,
+        f"culture_generation_{n_epoch}_{model.id}.png",
+        quizzes=result[:128],
+    )
######################################################################
-def create_c_quizzes(
- models,
- quiz_machine,
- nb_for_train=1000,
- nb_for_test=100,
+def one_complete_epoch(
+    model, n_epoch, train_c_quizzes, test_c_quizzes, local_device=main_device
):
-    quizzes_and_nb_correct_records = []
+    """Train model for one epoch, evaluate it, and save sample images.
+
+    The training pass uses train_c_quizzes; the test loss, the test
+    accuracy and the saved images use test_c_quizzes. The accuracy is
+    the proportion of quizzes of a fresh test set for which the masked
+    grid is predicted without any mistake; it is stored in
+    model.test_accuracy and logged.
+    """
+    one_epoch(model, n_epoch, train_c_quizzes, train=True, local_device=local_device)
+
+    one_epoch(model, n_epoch, test_c_quizzes, train=False, local_device=local_device)
+
+    # Compute the test accuracy
+
+    # test_c_quizzes, not c_quizzes: the latter is not defined in this
+    # function
+    quizzes = generate_quiz_set(
+        args.nb_test_samples, test_c_quizzes, args.c_quiz_multiplier
+    )
+    imt_set = samples_for_prediction_imt(quizzes.to(local_device))
+    result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
+    correct = (quizzes == result).min(dim=1).values.long()
-    nb_to_create = nb_for_train + nb_for_test
+    nb_correct, nb_total = correct.sum().item(), quizzes.size(0)
+    model.test_accuracy = nb_correct / nb_total
-    # ------------------------------------------------------------
+    log_string(
+        f"test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({model.test_accuracy*100:.02f}%)"
+    )
-    standard_validity = lambda nb_correct: torch.logical_and(
-        nb_correct >= args.min_to_validate, nb_correct <= args.max_to_validate
+    save_inference_images(
+        model,
+        n_epoch,
+        test_c_quizzes,
+        args.c_quiz_multiplier,
+        local_device=local_device,
    )
- file_name = os.path.join(args.result_dir, f"culture_c_quiz_{n_epoch:04d}_logp.dat")
- with open(file_name, "w") as logp_file:
- while (
- valid_c_quizzes(quizzes_and_nb_correct_records, standard_validity).size(0)
- < nb_to_create
- ):
- # Select a model at random to generate the new quizzes
+######################################################################
+
- model_for_generation = models[torch.randint(len(models), (1,))]
+def max_nb_mistakes_on_one_grid(quizzes, prediction):
+    """Return, for each quiz, the largest number of mistakes on a grid.
+
+    Each row is split in four parts of equal length, the components of
+    prediction that differ from quizzes are counted in each part, and
+    the maximum of the four counts is returned for every row.
+    """
+    mistakes = (prediction != quizzes).long()
+    nb_per_grid = mistakes.reshape(quizzes.size(0), 4, -1).sum(dim=2)
+    worst, _ = nb_per_grid.max(dim=1)
+    return worst
- c_quizzes = quiz_machine.generate_quizzes(
- nb_to_create,
- model_for_generation=model_for_generation,
- temperature=args.generation_temperature,
- )
- c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
+def evaluate_quizzes(quizzes, models, with_hints, local_device):
+    """Count, for each quiz, the models that solve it and that fail it.
+
+    Each model is copied, moved to local_device, and used to predict
+    the four grids of every quiz. A model is correct on a quiz when it
+    makes no mistake on any grid, and wrong when it makes at least
+    args.nb_mistakes_to_be_wrong mistakes on one of the grids. Returns
+    the pair (nb_correct, nb_wrong) of per-quiz counts.
+    """
+    nb_correct, nb_wrong = 0, 0
-            if c_quizzes.size(0) > 0:
-                nb_correct, seq_logproba = quiz_machine.compute_correctness(
-                    c_quizzes,
-                    models,
-                    bidirectional_validation=args.bidirectional_validation,
-                    deterministic_validation=args.deterministic_validation,
-                )
+    for model in models:
+        # Work on a copy so that the model passed by the caller is
+        # neither moved nor switched to eval mode
+        model = copy.deepcopy(model).to(local_device).eval()
+        predicted = predict_the_four_grids(
+            model=model,
+            input=quizzes,
+            with_noise=False,
+            with_hints=with_hints,
+            local_device=local_device,
+        )
+        nb_mistakes = max_nb_mistakes_on_one_grid(quizzes, predicted)
+        nb_correct += (nb_mistakes == 0).long()
+        nb_wrong += (nb_mistakes >= args.nb_mistakes_to_be_wrong).long()
+
+    # print("\n\n", nb_correct, nb_wrong)
+
+    return nb_correct, nb_wrong
+
+
+######################################################################
- for n, l in zip(nb_correct, seq_logproba):
- s = " ".join([str(x.item()) for x in l])
- logp_file.write(f"{n} {s}\n")
- if args.dirty_debug:
- nb_correct = torch.randint(
- len(models) + 1, nb_correct.size(), device=c_quizzes.device
- )
+def identity_quizzes(quizzes):
+    """Flag the quizzes in which a transformation is the identity.
+
+    Each row is split in four parts of equal length. The returned
+    Boolean vector is true for the rows whose first and second parts
+    are equal, or whose third and fourth parts are equal.
+    """
+    parts = quizzes.reshape(quizzes.size(0), 4, -1)
+    first_pair_same = (parts[:, 0] == parts[:, 1]).all(dim=1)
+    second_pair_same = (parts[:, 2] == parts[:, 3]).all(dim=1)
+    return first_pair_same | second_pair_same
- quizzes_and_nb_correct_records.append((c_quizzes, nb_correct))
- nv = F.one_hot(nb_correct, num_classes=len(models) + 1).sum(0)
- nv = " ".join([str(x.item()) for x in nv])
+def generate_c_quizzes(models, nb_to_generate, local_device=main_device):
+    """Generate at least nb_to_generate validated culture quizzes.
+
+    Batches of args.eval_batch_size * 10 quizzes are generated with a
+    copy of a model picked at random. The quizzes flagged by
+    identity_quizzes are discarded, and the ones kept are those that at
+    least args.nb_have_to_be_correct models solve and at least
+    args.nb_have_to_be_wrong models fail, according to
+    evaluate_quizzes with hints. The progress is logged every ten
+    seconds or so. All the validated quizzes of the last batch are
+    kept, so the returned tensor, located on the cpu, may hold more
+    than nb_to_generate rows.
+    """
+    record = []
+    nb_validated = 0
+
+    start_time = time.perf_counter()
+    last_log = -1
+
+    while nb_validated < nb_to_generate:
+        # Generate new quizzes
+
+        model = models[torch.randint(len(models), (1,)).item()]
+        model = copy.deepcopy(model).to(local_device).eval()
+        generator_id = model.id
+
+        c_quizzes = ae_generate(
+            model=model, nb=args.eval_batch_size * 10, local_device=local_device
+        )
+
+        c_quizzes = c_quizzes[identity_quizzes(c_quizzes) == False]
-                nb_validated = valid_c_quizzes(
-                    quizzes_and_nb_correct_records, standard_validity
-                ).size(0)
+        if c_quizzes.size(0) > 0:
+            # Select the ones that are solved properly by some models and
+            # not understood by others
+
+            nb_correct, nb_wrong = evaluate_quizzes(
+                quizzes=c_quizzes,
+                models=models,
+                with_hints=True,
+                local_device=local_device,
+            )
+
+            to_keep = (nb_correct >= args.nb_have_to_be_correct) & (
+                nb_wrong >= args.nb_have_to_be_wrong
+            )
+
+            nb_validated += to_keep.long().sum().item()
+            record.append(c_quizzes[to_keep])
+
+        #####################
+
+        duration = time.perf_counter() - start_time
+
+        if last_log < 0 or duration > last_log + 10:
+            last_log = duration
+            if nb_validated > 0:
+                if nb_validated < nb_to_generate:
+                    # Remaining time extrapolated from the average rate
+                    # of validation so far
+                    d = (nb_to_generate - nb_validated) * duration / nb_validated
+                    e = (
+                        datetime.datetime.now() + datetime.timedelta(seconds=d)
+                    ).strftime("%a %H:%M")
+                else:
+                    e = "now!"
+            else:
+                e = "???"
                log_string(
-                    f"keep c_quizzes model {model_for_generation.id} kept {nv} nb_accumulated {nb_validated} / {nb_to_create}"
+                f"nb_validated {nb_validated} model {generator_id} (finishes {e} -- {int((nb_validated * 3600)/duration)}/h)"
                )
-    # store the new c_quizzes which have been validated
+        #####################
+
+    duration = time.perf_counter() - start_time
-    new_c_quizzes = valid_c_quizzes(quizzes_and_nb_correct_records, standard_validity)
+    log_string(f"generate_c_quizz_speed {int(3600 * nb_validated / duration)}/h")
-    quiz_machine.reverse_random_half_in_place(new_c_quizzes)
+    return torch.cat(record).to("cpu")
+
+
+######################################################################
- quiz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
- quiz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
- # save a bunch of images to investigate what quizzes with a
- # certain nb of correct predictions look like
+def multithread_execution(fun, arguments):
+    """Call fun once per tuple of arguments, in one thread per call.
+
+    arguments is a sequence of tuples of positional arguments. With a
+    single tuple, fun is called directly and its result is returned as
+    is. Otherwise the calls run in parallel threads and, unless fun
+    returns None, the results are returned as a list of tensors, the
+    k-th one being the concatenation along dimension 0 of the k-th
+    values returned by the calls. The results are gathered in the order
+    in which the threads complete.
+    """
+    # Single instance, no thread
+    if len(arguments) == 1:
+        return fun(*(arguments[0]))
-    for n in range(len(models) + 1):
-        s = (
-            "_validated"
-            if n >= args.min_to_validate and n <= args.max_to_validate
-            else ""
+    records, threads = [], []
+
+    def threadable_fun(*args):
+        r = fun(*args)
+        # Normalize a single returned value into a one-element tuple
+        if type(r) is not tuple:
+            r = (r,)
+        records.append(r)
+
+    # Here args is a local loop variable that shadows the global
+    # command line arguments inside this function
+    for args in arguments:
+        # To get a different sequence between threads
+        log_string(f"dummy_rand {torch.rand(1)}")
+        # torch.rand(1)
+        t = threading.Thread(target=threadable_fun, daemon=True, args=args)
+        threads.append(t)
+        t.start()
+
+    for t in threads:
+        t.join()
+
+    # fun returned None: there is nothing to concatenate
+    if records[0] == (None,):
+        return
+    else:
+        return [
+            torch.cat([x[k] for x in records], dim=0) for k in range(len(records[0]))
+        ]
+
+
+######################################################################
+
+
+def save_models(models, suffix=""):
+    """Save each model of models in args.result_dir.
+
+    The file of a model is named ae_<id>.pth, or ae_<id>_<suffix>.pth
+    when suffix is not empty, and holds the state of the model, the
+    state of its optimizer and its test accuracy.
+    """
+    if suffix != "":
+        suffix = "_" + suffix
+
+    for model in models:
+        filename = f"ae_{model.id:03d}{suffix}.pth"
+        torch.save(
+            {
+                "state_dict": model.state_dict(),
+                "optimizer_state_dict": model.optimizer.state_dict(),
+                "test_accuracy": model.test_accuracy,
+            },
+            os.path.join(args.result_dir, filename),
        )
-        q = valid_c_quizzes(
-            quizzes_and_nb_correct_records, criteria=lambda nb_correct: nb_correct == n
-        )[:72]
+    log_string(f"wrote ae_*{suffix}.pth")
- quiz_machine.reverse_random_half_in_place(q)
- if q.size(0) > 0:
- quiz_machine.save_quizzes(
- args.result_dir, f"culture_c_quiz_{n_epoch:04d}_N{n}{s}", q
- )
+######################################################################
+
+
+def save_quiz_image(models, c_quizzes, filename, local_device=main_device):
+    """Save an image of c_quizzes annotated with the models' results.
+
+    The quizzes are evaluated without hints by evaluate_quizzes, and
+    each of them is saved with a comment giving the number of models
+    that solve it and the number of models that fail it. The image is
+    written in args.result_dir under the name filename.
+    """
+    quizzes_on_device = c_quizzes.to(local_device)
+
+    nb_correct, nb_wrong = evaluate_quizzes(
+        quizzes=quizzes_on_device,
+        models=models,
+        with_hints=False,
+        local_device=local_device,
+    )
+
+    comments = []
+    for c, w in zip(nb_correct, nb_wrong):
+        comments.append(f"nb_correct {c} nb_wrong {w}")
+
+    image_options = {"comments": comments, "delta": True, "nrow": 8}
+    problem.save_quizzes_as_image(
+        args.result_dir, filename, quizzes=quizzes_on_device, **image_options
+    )
+
+    log_string(f"wrote {filename}")
+######################################################################
+
+problem = grids.Grids(
+ max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
+ chunk_size=100,
+ nb_threads=args.nb_threads,
+ tasks=args.grids_world_tasks,
+)
+
+if not args.resume:
+ problem.save_some_examples(args.result_dir)
+
+
+log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}")
+
+vocabulary_size = problem.vocabulary_size()
+
+log_string(f"vocabulary_size {vocabulary_size}")
+
######################################################################
models = []
-for k in range(args.nb_gpts):
- model = mygpt.MyGPT(
- vocabulary_size=vocabulary_size,
+if args.model_type == "standard":
+ model_constructor = attae.AttentionAE
+elif args.model_type == "functional":
+ model_constructor = attae.FunctionalAttentionAE
+else:
+ raise ValueError(f"Unknown model type {args.model_type}")
+
+
+for i in range(args.nb_models):
+ model = model_constructor(
+ vocabulary_size=vocabulary_size * 2,
dim_model=args.dim_model,
dim_keys=args.dim_keys,
dim_hidden=args.dim_hidden,
nb_heads=args.nb_heads,
nb_blocks=args.nb_blocks,
- causal=True,
dropout=args.dropout,
- ).to(device)
+ )
+
+ # model = torch.compile(model)
- model.main_test_accuracy = 0.0
- model.id = k
+ model.id = i
+ model.test_accuracy = 0.0
+ model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
models.append(model)
+######################################################################
+
+current_epoch = 0
-nb_parameters = sum(p.numel() for p in models[0].parameters())
-log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
+if args.resume:
+ for model in models:
+ filename = f"ae_{model.id:03d}.pth"
+
+ d = torch.load(
+ os.path.join(args.result_dir, filename),
+ map_location="cpu",
+ weights_only=False,
+ )
+ model.load_state_dict(d["state_dict"])
+ model.optimizer.load_state_dict(d["optimizer_state_dict"])
+ model.test_accuracy = d["test_accuracy"]
+ log_string(f"successfully loaded {filename}")
+
+ filename = "state.pth"
+ state = torch.load(
+ os.path.join(args.result_dir, filename),
+ map_location="cpu",
+ weights_only=False,
+ )
+
+ log_string(f"successfully loaded {filename}")
+
+ current_epoch = state["current_epoch"]
+ train_c_quizzes = state["train_c_quizzes"]
+ test_c_quizzes = state["test_c_quizzes"]
######################################################################
-nb_new_c_quizzes_for_train = args.nb_train_samples // 50
-nb_new_c_quizzes_for_test = args.nb_test_samples // 50
+nb_parameters = sum(p.numel() for p in models[0].parameters())
+log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
-log_string(
- f"nb_new_c_quizzes_for_train {nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {nb_new_c_quizzes_for_test}"
-)
######################################################################
-if args.dirty_debug:
- args.accuracy_to_make_c_quizzes = 0.0
- args.nb_gpts = 2
- nb_new_c_quizzes_for_train = 100
- nb_new_c_quizzes_for_test = 10
+# Start without any c_quizzes on a fresh run only. When resuming, both
+# variables were just restored from state.pth and must not be reset,
+# otherwise the accumulated c_quizzes are silently lost.
+if not args.resume:
+    train_c_quizzes, test_c_quizzes = None, None
######################################################################
-for n_epoch in range(args.nb_epochs):
- log_string(f"--- epoch {n_epoch} ----------------------------------------")
+for n_epoch in range(current_epoch, args.nb_epochs):
+ start_time = time.perf_counter()
- cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models])
- log_string(f"current_test_accuracies {cta}")
+ state = {
+ "current_epoch": n_epoch,
+ "train_c_quizzes": train_c_quizzes,
+ "test_c_quizzes": test_c_quizzes,
+ }
- ##################################################
- # Select, improve, and eval the worst model
+ filename = "state.pth"
+ torch.save(state, os.path.join(args.result_dir, filename))
+ log_string(f"wrote {filename}")
- weakest_model = min(models, key=lambda m: float(m.main_test_accuracy))
+ log_string(f"--- epoch {n_epoch} ----------------------------------------")
- log_string(
- f"training model {weakest_model.id} main_test_accuracy {weakest_model.main_test_accuracy}"
- )
+ cta = " ".join([f"{float(m.test_accuracy):.04f}" for m in models])
+ log_string(f"current_test_accuracies {cta}")
- one_epoch(weakest_model, quiz_machine)
+ # --------------------------------------------------------------------
- log_string(
- f"train_set_composition w_quizzes {quiz_machine.nb_batch_w_quizzes} c_quizzes {quiz_machine.nb_batch_c_quizzes}"
- )
+ lowest_test_accuracy = min([float(m.test_accuracy) for m in models])
- run_tests(weakest_model, quiz_machine, deterministic_synthesis=False)
+ if lowest_test_accuracy >= args.accuracy_to_make_c_quizzes:
+ if train_c_quizzes is None:
+ save_models(models, "naive")
- log_string(
- f"test_set_composition w_quizzes {quiz_machine.nb_batch_w_quizzes} c_quizzes {quiz_machine.nb_batch_c_quizzes}"
- )
+ nb_gpus = len(gpus)
+ nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus
- ##################################################
- # Replace a fraction of the w_quizzes with fresh ones
+ (new_c_quizzes,) = multithread_execution(
+ generate_c_quizzes,
+ [(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus],
+ )
+
+ save_quiz_image(
+ models, new_c_quizzes[:256], f"culture_c_quiz_{n_epoch:04d}.png"
+ )
- quiz_machine.renew_w_quizzes(args.nb_train_samples // args.nb_gpts)
+ log_string(f"generated_c_quizzes {new_c_quizzes.size()}")
- ##################################################
- # If all the models are good enough, generate new quizzes and
- # re-compute the test errors
+ train_c_quizzes = (
+ new_c_quizzes
+ if train_c_quizzes is None
+ else torch.cat([train_c_quizzes, new_c_quizzes])
+ )
+ train_c_quizzes = train_c_quizzes[-args.nb_train_samples :]
- if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
- create_c_quizzes(
- models,
- quiz_machine,
- nb_for_train=nb_new_c_quizzes_for_train,
- nb_for_test=nb_new_c_quizzes_for_test,
+ nb_correct, _ = evaluate_quizzes(
+ quizzes=train_c_quizzes,
+ models=models,
+ with_hints=False,
+ local_device=local_device,
)
+ test_c_quizzes = train_c_quizzes[nb_correct >= args.nb_have_to_be_correct]
+
for model in models:
- run_tests(model, quiz_machine, deterministic_synthesis=False)
+ model.test_accuracy = 0
+ if train_c_quizzes is None:
+ log_string("no_c_quiz")
+ else:
+ log_string(f"nb_c_quizzes {train_c_quizzes.size(0)}")
-######################################################################
+ # --------------------------------------------------------------------
+
+ ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
+ weakest_models = ranked_models[: len(gpus)]
+
+ log_string(
+ f"weakest_accuracies {[model.test_accuracy for model in weakest_models]}"
+ )
+
+ multithread_execution(
+ one_complete_epoch,
+ [
+ (model, n_epoch, train_c_quizzes, test_c_quizzes, gpu)
+ for model, gpu in zip(weakest_models, gpus)
+ ],
+ )
+
+ save_models(models)
+
+ # --------------------------------------------------------------------
+
+ duration = time.perf_counter() - start_time
+ str_duration = ""
+ if duration >= 60:
+ str_duration += f"{int(duration)//60}min"
+ str_duration += f"{int(duration)%60}s"
+ str_next = (
+ datetime.datetime.now() + datetime.timedelta(seconds=duration)
+ ).strftime("%H:%M:%S")
+ log_string(f"epoch_duration {str_duration} next_finish {str_next}")
+++ /dev/null
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-# This is an implementation from scratch of a "GPT", that is a model
-# composed of several causal self-attention blocks. It is equipped
-# with a caching mechanism for keys and values to avoid a O(N^3) cost
-# for auto-regression.
-
-import math
-
-import torch
-
-from torch import nn
-from torch.nn import functional as F
-
-######################################################################
-
-# A BracketedSequence is a BxTx... tensor with a first and a nb time
-# steps to compute.
-
-# Modules able to process it expect that they will have to process a
-# first bracket starting at t=0, followed by a succession of brackets
-# that move forward in time, do not overlap, and cover the axis T with
-# no holes.
-#
-# Although it is more general, for a classical prompt-conditioned
-# auto-regressive process it will be a first bracket starting at 0 and
-# of arbitrary length for the "prompt", followed by brackets of length
-# 1 for the successive tokens.
-#
-# Modules able to process brackets may implement a cache that is
-# resetted when the input bracket starts at t=0
-
-
-class BracketedSequence:
- def __init__(self, x, first=None, nb=None):
- self.x = x
- self.first = 0 if first is None else first
- self.nb = x.size(1) if nb is None else nb
-
- def slice(self):
- return self.x[:, self.first : self.first + self.nb]
-
- def complete(self):
- return self.first == 0 and self.nb == self.x.size(1)
-
-
-######################################################################
-
-
-class CacheWrapper(nn.Module):
- def __init__(self, *f):
- super().__init__()
- self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
-
- def forward(self, bs):
- if bs.first == 0:
- y = self.f(bs.slice())
- self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:]))
- self.cache_y[:, bs.first : bs.first + bs.nb] = y
- else:
- self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
-
- return BracketedSequence(self.cache_y, bs.first, bs.nb)
-
-
-##############################
-
-
-class WithResidual(nn.Module):
- def __init__(self, *f):
- super().__init__()
- self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
-
- def forward(self, bs):
- return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb)
-
-
-##############################
-
-
-class AddPositionalEncoding(nn.Module):
- def __init__(self, len_max):
- super().__init__()
- self.len_max = len_max
-
- # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
-
- def forward(self, bs):
- if bs.first == 0:
- t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
- :, None
- ]
- j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
- None, :
- ]
- k = j % 2
- self.pe = torch.sin(
- t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
- )
- self.cache_y = bs.x.new(bs.x.size())
-
- self.cache_y[:, bs.first : bs.first + bs.nb] = (
- bs.slice() + self.pe[bs.first : bs.first + bs.nb]
- )
-
- return BracketedSequence(self.cache_y, bs.first, bs.nb)
-
-
-##############################
-
-
-class QKVAttention(nn.Module):
- def __init__(
- self,
- dim_in,
- dim_qk,
- dim_v,
- nb_heads=1,
- causal=False,
- attention_dropout=0.0,
- ):
- super().__init__()
-
- def randw(*d):
- return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
-
- self.causal = causal
- self.attention_dropout = attention_dropout
- self.record_attention = False
-
- self.w_q = randw(nb_heads, dim_qk, dim_in)
- self.w_k = randw(nb_heads, dim_qk, dim_in)
- self.w_v = randw(nb_heads, dim_v, dim_in)
- self.w_o = randw(dim_v * nb_heads, dim_in)
-
- def forward(self, bs_q):
- x_q = bs_q.x
-
- assert (
- self.causal or bs_q.complete()
- ), "Partial evaluation is only possible for causal models"
-
- if bs_q.first == 0:
- self.cache_k = x_q.new_zeros(
- x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
- )
- self.cache_v = x_q.new_zeros(
- x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
- )
- self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
-
- q = torch.einsum(
- "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_q
- )
-
- self.cache_k[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum(
- "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_k
- )
- self.cache_v[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum(
- "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_v
- )
-
- a = torch.einsum(
- "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb]
- ) / math.sqrt(self.w_q.size(1))
-
- if self.causal:
- if bs_q.first == 0:
- self.cache_attzero = (
- torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
- < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
- )
- a = a.masked_fill(
- self.cache_attzero[
- :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb
- ],
- float("-inf"),
- )
-
- a = a.softmax(dim=3)
-
- if self.record_attention:
- self.a = a
-
- a = F.dropout(a, self.attention_dropout, self.training)
-
- y = torch.einsum(
- "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs_q.first + bs_q.nb]
- ).flatten(2)
-
- self.cache_y[:, bs_q.first : bs_q.first + bs_q.nb] = y @ self.w_o
-
- return BracketedSequence(self.cache_y, bs_q.first, bs_q.nb)
-
-
-##############################
-
-
-class NoiseInjector(nn.Module):
- def __init__(self):
- super().__init__()
- self.noise_std = 0.0
-
- def forward(self, x):
- if self.noise_std > 0:
- x = x + torch.randn(x.size(), device=x.device) * self.noise_std
- return x
-
-
-def set_noise_injection(model, noise_std):
- for m in model.modules():
- if isinstance(m, NoiseInjector):
- m.noise_std = noise_std
-
-
-##############################
-
-
-class MyGPT(nn.Module):
- def __init__(
- self,
- vocabulary_size,
- dim_model,
- dim_keys,
- dim_hidden,
- nb_heads,
- nb_blocks,
- causal=False,
- dropout=0.0,
- len_max=1e5,
- ):
- super().__init__()
-
- assert dim_model % nb_heads == 0
-
- self.embedding = nn.Sequential(
- CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
- AddPositionalEncoding(len_max),
- )
-
- trunk_blocks = []
-
- for b in range(nb_blocks):
- trunk_blocks += [
- WithResidual(
- CacheWrapper(
- nn.LayerNorm((dim_model,)),
- NoiseInjector(),
- ),
- QKVAttention(
- dim_in=dim_model,
- dim_qk=dim_keys,
- dim_v=dim_model // nb_heads,
- nb_heads=nb_heads,
- causal=causal,
- attention_dropout=dropout,
- ),
- ),
- WithResidual(
- CacheWrapper(
- nn.LayerNorm((dim_model,)),
- NoiseInjector(),
- nn.Linear(in_features=dim_model, out_features=dim_hidden),
- nn.ReLU(),
- nn.Linear(in_features=dim_hidden, out_features=dim_model),
- nn.Dropout(dropout),
- ),
- ),
- ]
-
- self.trunk = nn.Sequential(*trunk_blocks)
-
- self.readout = CacheWrapper(
- nn.Linear(in_features=dim_model, out_features=vocabulary_size)
- )
-
- with torch.no_grad():
- for m in self.modules():
- if isinstance(m, nn.Embedding):
- m.weight.normal_(mean=0, std=2e-2)
- elif isinstance(m, nn.LayerNorm):
- m.bias.zero_()
- m.weight.fill_(1.0)
-
- def forward(self, bs):
- # print(f"GENERATE {bs.first} {bs.first+bs.nb}")
- bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
- bs = self.embedding(bs)
- bs = self.trunk(bs)
- bs = self.readout(bs)
- return bs
-
- def record_attention(self, v=True):
- for m in self.modules():
- if isinstance(m, QKVAttention):
- m.record_attention = v
-
- def retrieve_attention(self):
- a = []
- for m in self.modules():
- if isinstance(m, QKVAttention):
- a.append(m.a)
- return a
-
-
-######################################################################
-
-if __name__ == "__main__":
- print("Basic check.")
-
- vocabulary_size = 3
- x = torch.randint(vocabulary_size, (1, 5))
-
- model = MyGPT(
- vocabulary_size=vocabulary_size,
- dim_model=4,
- dim_keys=2,
- dim_hidden=2,
- nb_heads=2,
- nb_blocks=2,
- dropout=0.1,
- causal=True,
- )
-
- model.eval()
- y1 = model(BracketedSequence(x)).x
- y2 = torch.randn_like(y1)
- for s in range(x.size(1)):
- z = model(BracketedSequence(x, s, 1))
- y2[:, s] = z.slice()
-
- print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
-
-######################################################################
class Problem:
- def nb_token_values(self):
- pass
-
- def trivial_prompts_and_answers(self, prompts, answers):
- pass
-
- # returns two tensors nb x D and nb x D'
- def generate_prompts_and_answers(self, nb):
- pass
-
- # save a file to vizualize quizzes, you can save a txt or png file
- def save_quizzes(
- self,
- result_dir,
- filename_prefix,
- prompts,
- answers,
- predicted_prompts=None,
- predicted_answers=None,
- ):
- pass
-
-
-class MultiThreadProblem:
- def __init__(self, problem, max_nb_cached_chunks, chunk_size, nb_threads=1):
- self.problem = problem
- self.chunk_size = chunk_size
- self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
- for _ in range(nb_threads):
- threading.Thread(target=self.fill_cache, daemon=True).start()
- self.rest = None
-
- def nb_token_values(self):
- return self.problem.nb_token_values()
+ def __init__(self, max_nb_cached_chunks=None, chunk_size=None, nb_threads=-1):
+ if nb_threads > 0:
+ self.chunk_size = chunk_size
+ self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
+ for _ in range(nb_threads):
+ threading.Thread(target=self.fill_cache, daemon=True).start()
+ self.rest = None
+ else:
+ self.queue = None
- def save_quizzes(
- self,
- result_dir,
- filename_prefix,
- prompts,
- answers,
- predicted_prompts=None,
- predicted_answers=None,
- ):
- self.problem.save_quizzes(
- result_dir,
- filename_prefix,
- prompts,
- answers,
- predicted_prompts=None,
- predicted_answers=None,
- )
+ def nb_cached_quizzes(self):
+ if self.queue is None:
+ return None
+ else:
+ return self.queue.qsize() * self.chunk_size
def fill_cache(self):
while True:
- prompts, answers = self.problem.generate_prompts_and_answers(
- self.chunk_size
- )
+ quizzes = self.generate_w_quizzes_(self.chunk_size)
+ self.queue.put(quizzes.to("cpu"), block=True)
- self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True)
+ def generate_w_quizzes(self, nb, progress_bar=True):
+ if self.queue is None:
+ return self.generate_w_quizzes_(nb)
- def trivial_prompts_and_answers(self, prompts, answers):
- return self.problem.trivial_prompts_and_answers(prompts, answers)
-
- def generate_prompts_and_answers(self, nb):
if self.rest is not None:
- prompts, answers = rest
+            # Reuse the surplus quizzes kept by the previous call; they are
+            # stored as a single tensor, so wrap it in a list to append to.
+            quizzes = [self.rest]
         else:
-            prompts, answers = [], []
+            quizzes = []
         self.rest = None
-        n = sum([p.size(0) for p in prompts])
-
-        with tqdm.tqdm(
-            total=nb,
-            dynamic_ncols=True,
-            desc="world generation",
-        ) as pbar:
+        n = sum([q.size(0) for q in quizzes])
+
+        if progress_bar:
+            with tqdm.tqdm(
+                total=nb, dynamic_ncols=True, desc="world generation", delay=10
+            ) as pbar:
+                while n < nb:
+                    q = self.queue.get(block=True)
+                    quizzes.append(q)
+                    n += q.size(0)
+                    pbar.update(q.size(0))
+        else:
             while n < nb:
-                p, s = self.queue.get(block=True)
-                prompts.append(p)
-                answers.append(s)
-                n += p.size(0)
-                pbar.update(p.size(0))
+                q = self.queue.get(block=True)
+                quizzes.append(q)
+                n += q.size(0)
-        prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0)
-        assert n == prompts.size(0)
+        quizzes = torch.cat(quizzes, dim=0)
+        assert n == quizzes.size(0)
         k = n - nb
         if k > 0:
-            rest = (prompts[-k:], answers[-k:])
-            prompts, answers = prompts[:-k], answers[:-k]
+            # Keep the surplus on the instance so the next call can use it;
+            # assigning to a bare local would drop these quizzes.
+            self.rest = quizzes[-k:]
+            quizzes = quizzes[:-k]
+
+ return quizzes
+
+ ######################################################################
+
+ def trivial_prompts_and_answers(self, prompts, answers):
+ pass
+
+    # The one to implement, returns a single tensor holding nb quizzes (one per row)
+ def generate_w_quizzes_(self, nb):
+ pass
+
+    # save a file to visualize quizzes, you can save a txt or png file
+ def save_quiz_illustrations(
+ self,
+ result_dir,
+ filename_prefix,
+ prompts,
+ answers,
+ predicted_prompts=None,
+ predicted_answers=None,
+ ):
+ pass
+
+ def save_some_examples(self, result_dir):
+ pass
- return prompts, answers
+ ######################################################################
+++ /dev/null
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import math, os, tqdm, warnings
-
-import torch, torchvision
-
-from torch import nn
-from torch.nn import functional as F
-
-import mygpt
-from mygpt import BracketedSequence
-
-######################################################################
-
-# ar_mask is a tensor with 0s and 1s, of same shape as input, with
-# 1s where tokens should be generated. The others are kept
-# unchanged.
-
-
-def one_batch_masked_inplace_autoregression(
- model,
- input,
- ar_mask,
- seq_logproba,
- temperature,
- deterministic_synthesis,
-):
- to_generate = (ar_mask.sum(0) > 0).nonzero()
-
- if to_generate.min() > 0:
- model(
- BracketedSequence(input, 0, to_generate.min())
- ) # Needed to initialize the model's cache
- for s in range(to_generate.min(), to_generate.max() + 1):
- output = model(BracketedSequence(input, s, 1)).x
-
- logits = output[:, s]
-
- logits = (logits / temperature).log_softmax(dim=-1)
-
- if deterministic_synthesis:
- t_next = logits.argmax(-1)
- else:
- dist = torch.distributions.categorical.Categorical(logits=logits)
- t_next = dist.sample()
-
- all_n = torch.arange(t_next.size(0))
-
- seq_logproba += logits[all_n, t_next]
-
- input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
-
-
-def masked_inplace_autoregression(
- model,
- batch_size,
- input,
- ar_mask,
- seq_logproba,
- temperature,
- deterministic_synthesis,
- forbidden_tokens=None,
- logit_biases=None,
- progress_bar_desc=None,
- device=torch.device("cpu"),
-):
- assert input.size() == ar_mask.size()
-
- batches = zip(
- input.split(batch_size),
- ar_mask.split(batch_size),
- seq_logproba.split(batch_size),
- )
-
- if progress_bar_desc is not None:
- batches = tqdm.tqdm(
- batches,
- dynamic_ncols=True,
- desc=progress_bar_desc,
- total=(input.size(0) + batch_size - 1) // batch_size,
- )
-
- with torch.autograd.no_grad():
- t = model.training
- model.eval()
-
- for input, ar_mask, seq_logproba in batches:
- one_batch_masked_inplace_autoregression(
- model=model,
- input=input,
- ar_mask=ar_mask,
- seq_logproba=seq_logproba,
- temperature=temperature,
- deterministic_synthesis=deterministic_synthesis,
- )
-
- model.train(t)
-
-
-######################################################################
-
-
-class QuizMachine:
- def indices_forward_and_backward(self, quizzes):
- i_forward = quizzes[:, 0] == self.token_forward
- j_forward = quizzes[:, 1 + self.prompt_len] == self.token_forward
- i_backward = quizzes[:, 0] == self.token_backward
- j_backward = quizzes[:, 1 + self.answer_len] == self.token_backward
- assert torch.logical_or(
- torch.logical_and(i_forward, j_forward),
- torch.logical_and(i_backward, j_backward),
- ).all()
- return i_forward, i_backward
-
- def non_trivial(self, quizzes):
- quizzes = quizzes.clone()
- n_forward = quizzes[quizzes[:, 0] == self.token_forward]
- n_backward = quizzes[:, 0] == self.token_backward
- backward = quizzes[n_backward]
- quizzes[n_backward] = self.reverse_time(quizzes[n_backward])
- return torch.logical_not(
- self.problem.trivial_prompts_and_answers(
- quizzes[:, 1 : 1 + self.prompt_len],
- quizzes[:, 2 + self.prompt_len :],
- )
- )
-
- def reverse_time(self, quizzes):
- i_forward, i_backward = self.indices_forward_and_backward(quizzes)
-
- forward_to_backward = torch.cat(
- [
- quizzes[:, 0:1],
- quizzes[:, 2 + self.prompt_len : 2 + self.prompt_len + self.answer_len],
- quizzes[:, 1 + self.prompt_len : 1 + self.prompt_len + 1],
- quizzes[:, 1 : 1 + self.prompt_len],
- ],
- dim=1,
- )
-
- forward_to_backward[:, 0] = self.token_backward
- forward_to_backward[:, 1 + self.answer_len] = self.token_backward
-
- backward_to_forward = torch.cat(
- [
- quizzes[:, 0:1],
- quizzes[:, 2 + self.answer_len :],
- quizzes[:, 1 + self.answer_len : 2 + self.answer_len],
- quizzes[:, 1 : 1 + self.answer_len],
- ],
- dim=1,
- )
-
- backward_to_forward[:, 0] = self.token_forward
- backward_to_forward[:, 1 + self.prompt_len] = self.token_forward
-
- m = i_forward.long()[:, None]
-
- return m * forward_to_backward + (1 - m) * backward_to_forward
-
- def reverse_random_half_in_place(self, quizzes):
- i = torch.rand(quizzes.size(0)) < 0.5
- if i.any():
- quizzes[i] = self.reverse_time(quizzes[i])
-
- def make_ar_mask(self, quizzes, first=False):
- i_forward, i_backward = self.indices_forward_and_backward(quizzes)
-
- t = torch.arange(quizzes.size(1), device=quizzes.device)
-
- if first:
- m_forward = (t >= 1).long() * (t < 1 + self.prompt_len).long()
- m_backward = (t >= 1).long() * (t < 1 + self.answer_len).long()
- else:
- m_forward = (t >= 2 + self.prompt_len).long()
- m_backward = (t >= 2 + self.answer_len).long()
-
- m = i_forward.long()[:, None]
-
- return m * m_forward + (1 - m) * m_backward
-
- def generate_token_sequences(self, nb):
- prompts, answers = self.problem.generate_prompts_and_answers(nb)
-
- if self.prompt_len is None:
- self.prompt_len = prompts.size(1)
-
- if self.answer_len is None:
- self.answer_len = answers.size(1)
-
- assert prompts.size(1) == self.prompt_len and answers.size(1) == self.answer_len
-
- result = []
-
- for prompt, answer in zip(prompts, answers):
- a = [
- torch.tensor([self.token_forward]),
- prompt,
- torch.tensor([self.token_forward]),
- answer,
- ]
-
- result.append(torch.cat(a, dim=0)[None, :])
-
- return torch.cat(result, dim=0)
-
- def __init__(
- self,
- problem,
- nb_train_samples,
- nb_test_samples,
- back_accuracy,
- batch_size,
- result_dir,
- logger,
- device=torch.device("cpu"),
- ):
- super().__init__()
-
- v = problem.nb_token_values()
- self.token_forward = v
- self.token_backward = v + 1
- self.nb_token_values = v + 2
-
- self.problem = problem
- self.back_accuracy = back_accuracy
- self.batch_size = batch_size
- self.device = device
- self.logger = logger
- self.prompt_len = None
- self.answer_len = None
-
- self.train_w_quizzes = self.generate_token_sequences(nb_train_samples)
- self.reverse_random_half_in_place(self.train_w_quizzes)
- self.train_w_quizzes = self.train_w_quizzes.to(device)
-
- self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device)
- self.reverse_random_half_in_place(self.test_w_quizzes)
- self.test_w_quizzes = self.test_w_quizzes.to(device)
-
- self.train_c_quizzes = []
- self.test_c_quizzes = []
-
- if result_dir is not None:
- self.save_quizzes(
- result_dir,
- "culture_w_quizzes",
- self.train_w_quizzes[:72],
- )
-
- def save_quizzes(
- self,
- result_dir,
- filename_prefix,
- quizzes,
- mistakes=None,
- ):
- quizzes = quizzes.clone()
- n_forward = quizzes[quizzes[:, 0] == self.token_forward]
- n_backward = quizzes[:, 0] == self.token_backward
- backward = quizzes[n_backward]
- assert n_forward.size(0) + backward.size(0) == quizzes.size(0)
- quizzes[n_backward] = self.reverse_time(quizzes[n_backward])
-
- predicted_prompts = n_backward.long()
- predicted_answers = 1 - predicted_prompts
- if mistakes is not None:
- # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
- predicted_prompts *= mistakes
- predicted_answers *= mistakes
- else:
- # 0/2 ~ not-to-predict / to predict
- predicted_prompts *= 2
- predicted_answers *= 2
-
- self.problem.save_quizzes(
- result_dir,
- filename_prefix,
- quizzes[:, 1 : 1 + self.prompt_len],
- quizzes[:, 2 + self.prompt_len :],
- predicted_prompts,
- predicted_answers,
- )
-
- def batches(self, split="train", desc=None):
- assert split in {"train", "test"}
- if split == "train":
- w_quizzes = self.train_w_quizzes
- c_quizzes = self.train_c_quizzes
- else:
- w_quizzes = self.test_w_quizzes
- c_quizzes = self.test_c_quizzes
-
- if len(c_quizzes) > 0:
- c_quizzes = torch.cat(c_quizzes, dim=0)
- if c_quizzes.size(0) > w_quizzes.size(0) // 2:
- i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
- c_quizzes = c_quizzes[i]
-
- i = torch.randperm(w_quizzes.size(0))[
- : w_quizzes.size(0) - c_quizzes.size(0)
- ]
- w_quizzes = w_quizzes[i]
-
- self.nb_batch_w_quizzes = w_quizzes.size(0)
- self.nb_batch_c_quizzes = c_quizzes.size(0)
-
- input = torch.cat([w_quizzes, c_quizzes], dim=0)
- else:
- input = w_quizzes
- self.nb_batch_w_quizzes = w_quizzes.size(0)
- self.nb_batch_c_quizzes = 0
-
- # Shuffle
- input = input[torch.randperm(input.size(0))]
-
- if desc is None:
- desc = f"epoch-{split}"
- for batch in tqdm.tqdm(
- input.split(self.batch_size), dynamic_ncols=True, desc=desc
- ):
- yield batch
-
- def vocabulary_size(self):
- return self.nb_token_values
-
- def produce_results(
- self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
- ):
- def compute_accuracy(input, log_prefix=None):
- ar_mask = self.make_ar_mask(input)
- result = input.clone() * (1 - ar_mask)
- seq_logproba = torch.empty(input.size(0), device=self.device)
-
- masked_inplace_autoregression(
- model=model,
- batch_size=self.batch_size,
- input=result,
- ar_mask=ar_mask,
- seq_logproba=seq_logproba,
- temperature=1.0,
- deterministic_synthesis=deterministic_synthesis,
- progress_bar_desc=None,
- device=self.device,
- )
-
- correct = torch.empty(input.size(0), dtype=torch.int64, device=input.device)
-
- n_forward = input[:, 0] == self.token_forward
- n_backward = input[:, 0] == self.token_backward
-
- correct[n_forward] = (
- (input[n_forward] == result[n_forward]).long().min(dim=1).values
- )
-
- if self.back_accuracy and n_backward.any():
- # accuracy of B->A*->B*=B instead of B->A*=A
- back_input = self.reverse_time(result[n_backward])
- back_input[:, 2 + self.prompt_len :] = input[
- n_backward, 1 : 1 + self.answer_len
- ]
- _, correct[n_backward] = compute_accuracy(back_input)
-
- if log_prefix is not None:
- forward_nb_correct = correct[n_forward].sum()
- forward_nb_total = correct[n_forward].size(0)
- backward_nb_correct = correct[n_backward].sum()
- backward_nb_total = correct[n_backward].size(0)
-
- self.logger(
- f"{log_prefix}_forward_accuracy {n_epoch} model {model.id} nb_correct {forward_nb_correct} / {forward_nb_total} ({forward_nb_correct*100/forward_nb_total} %)"
- )
-
- self.logger(
- f"{log_prefix}_backward_accuracy {n_epoch} model {model.id} nb_correct {backward_nb_correct} / {backward_nb_total} ({backward_nb_correct*100/backward_nb_total} %)"
- )
-
- return result, correct
-
- compute_accuracy(self.train_w_quizzes[:nmax], log_prefix="train")
-
- test_result, test_correct = compute_accuracy(
- self.test_w_quizzes[:nmax], log_prefix="test"
- )
-
- main_test_accuracy = test_correct.sum() / test_correct.size(0)
- self.logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}")
-
- ##############################
-
- self.save_quizzes(
- result_dir,
- f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
- quizzes=test_result[:72],
- mistakes=test_correct[:72] * 2 - 1,
- )
-
- return main_test_accuracy
-
- def renew_w_quizzes(self, nb, for_train=True):
- input = self.train_w_quizzes if for_train else self.test_w_quizzes
- nb = min(nb, input.size(0))
- input[:-nb] = input[nb:].clone()
- fresh_w_quizzes = self.generate_token_sequences(nb)
- self.reverse_random_half_in_place(fresh_w_quizzes)
- input[-nb:] = fresh_w_quizzes.to(self.device)
-
- def store_c_quizzes(self, new_c_quizzes, for_train=True):
- if for_train:
- self.train_c_quizzes.append(new_c_quizzes)
- else:
- self.test_c_quizzes.append(new_c_quizzes)
-
- def compute_correctness(
- self,
- c_quizzes,
- models_for_validation,
- bidirectional_validation=False,
- deterministic_validation=True,
- ):
- if bidirectional_validation:
- backward_c_quizzes = self.forward_to_backward(c_quizzes)
-
- seq_logproba = torch.zeros(
- c_quizzes.size(0),
- max([m.id for m in models_for_validation]) + 1,
- device=self.device,
- )
-
- nb_correct = 0
-
- seq_logproba[...] = 0.0
-
- for model in models_for_validation:
- result = c_quizzes.clone()
-
- ar_mask = self.make_ar_mask(result)
-
- masked_inplace_autoregression(
- model=model,
- batch_size=self.batch_size,
- input=result,
- ar_mask=ar_mask,
- seq_logproba=seq_logproba[:, model.id],
- temperature=1.0,
- deterministic_synthesis=deterministic_validation,
- # progress_bar_desc="solving c_quizzes",
- device=self.device,
- )
-
- correct = (c_quizzes == result).long().min(dim=-1).values
-
- if bidirectional_validation:
- backward_result = backward_c_quizzes.clone()
-
- ar_mask = self.make_ar_mask(backward_result)
-
- masked_inplace_autoregression(
- model=model,
- batch_size=self.batch_size,
- input=backward_result,
- ar_mask=ar_mask,
- seq_logproba=seq_logproba[:, model.id],
- temperature=1.0,
- deterministic_synthesis=deterministic_validation,
- # progress_bar_desc="solving backward c_quizzes",
- device=self.device,
- )
-
- backward_correct = (
- (backward_c_quizzes == backward_result).long().min(dim=-1).values
- )
-
- correct *= backward_correct
-
- # endif
-
- nb_correct += correct
-
- return nb_correct, seq_logproba
-
- ###############################################################
-
- def generate_quizzes(self, nb, model_for_generation, temperature=1.0):
- c_quizzes = torch.empty(
- nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
- )
-
- seq_logproba = torch.zeros(nb, device=self.device)
-
- # First, we generate the answer at high temperature
-
- c_quizzes[:, 0] = self.token_backward
- c_quizzes[:, 1 + self.answer_len] = self.token_backward
-
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes,
- ar_mask=self.make_ar_mask(c_quizzes, first=True),
- seq_logproba=seq_logproba,
- temperature=temperature,
- deterministic_synthesis=False,
- device=self.device,
- )
-
- # Then, we generate the prompt at low temperature
-
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes,
- ar_mask=self.make_ar_mask(c_quizzes),
- seq_logproba=seq_logproba,
- temperature=1 / temperature,
- deterministic_synthesis=False,
- device=self.device,
- )
-
- # Then we return the quizz, and re-generate the response, now
- # at low temperature
-
- c_quizzes = self.reverse_time(c_quizzes)
-
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes,
- ar_mask=self.make_ar_mask(c_quizzes),
- seq_logproba=seq_logproba,
- temperature=1 / temperature,
- deterministic_synthesis=False,
- device=self.device,
- )
-
- return c_quizzes
+++ /dev/null
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import math, sys, tqdm, os, warnings
-
-import torch, torchvision
-
-from torch import nn
-from torch.nn import functional as F
-
-######################################################################
-
-import problem
-
-
-class Sky(problem.Problem):
- colors = torch.tensor(
- [
- [255, 255, 255],
- [255, 0, 0],
- [0, 192, 0],
- [0, 0, 255],
- [255, 192, 0],
- [0, 255, 255],
- [255, 0, 255],
- [192, 255, 192],
- [255, 192, 192],
- [192, 192, 255],
- [192, 192, 192],
- ]
- )
-
- token_background = 0
- first_bird_token = 1
- nb_bird_tokens = colors.size(0) - 1
-
- token2char = (
- "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
- )
-
- def __init__(
- self,
- height=6,
- width=8,
- nb_birds=3,
- speed=2,
- nb_iterations=2,
- avoid_collision=True,
- ):
- self.height = height
- self.width = width
- self.nb_birds = nb_birds
- self.speed = speed
- self.nb_iterations = nb_iterations
- self.avoid_collision = avoid_collision
-
- def generate_frame_sequences(self, nb):
- frame_sequences = []
-
- for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
- i, j, vi, vj = (
- torch.empty(self.nb_birds, dtype=torch.int64),
- torch.empty(self.nb_birds, dtype=torch.int64),
- torch.empty(self.nb_birds, dtype=torch.int64),
- torch.empty(self.nb_birds, dtype=torch.int64),
- )
-
- def collision_okay():
- if not self.avoid_collision:
- return True
-
- count = torch.zeros(self.height, self.width, dtype=torch.int64)
-
- for n in range(self.nb_birds):
- count[i[n], j[n]] += 1
- count[i[n] - vi[n], j[n]] += 1
- count[i[n], j[n] - vj[n]] += 1
-
- return count.max() <= 1
-
- col = (
- torch.randperm(self.colors.size(0) - 1)[: self.nb_birds].sort().values
- + 1
- )
-
- while True:
- while True:
- for n in range(self.nb_birds):
- while True:
- i[n] = torch.randint(self.height, (1,))
- j[n] = torch.randint(self.width, (1,))
- vm = torch.randint(4, (1,))
- vi[n], vj[n] = (vm % 2) * 2 - 1, (vm // 2) * 2 - 1
- if (
- i[n] - vi[n] >= 0
- and i[n] - vi[n] < self.height
- and j[n] - vj[n] >= 0
- and j[n] - vj[n] < self.width
- ):
- break
-
- if collision_okay():
- break
-
- result = torch.zeros(
- self.nb_iterations * self.speed,
- self.height,
- self.width,
- dtype=torch.int64,
- )
-
- fine = torch.empty(self.nb_iterations * self.speed)
-
- t_to_keep = (
- torch.arange(self.nb_iterations, device=result.device) * self.speed
- )
-
- for l in range(self.nb_iterations * self.speed):
- fine[l] = collision_okay()
- for n in range(self.nb_birds):
- c = col[n]
- result[l, i[n], j[n]] = c
- result[l, i[n] - vi[n], j[n]] = c
- result[l, i[n], j[n] - vj[n]] = c
-
- if (i[n] == 0 and vi[n] == -1) or (
- i[n] == self.height - 1 and vi[n] == 1
- ):
- vi[n] = -vi[n]
-
- if (j[n] == 0 and vj[n] == -1) or (
- j[n] == self.width - 1 and vj[n] == 1
- ):
- vj[n] = -vj[n]
-
- i[n] += vi[n]
- j[n] += vj[n]
-
- result = result[t_to_keep]
- fine = fine[t_to_keep]
-
- if fine[-1]:
- break
-
- frame_sequences.append(result)
-
- return frame_sequences
-
- ######################################################################
-
- def frame2img(self, x, scale=15):
- x = x.reshape(x.size(0), self.height, -1)
- m = torch.logical_and(
- x >= 0, x < self.first_bird_token + self.nb_bird_tokens
- ).long()
- x = self.colors[x * m].permute(0, 3, 1, 2)
- s = x.shape
- x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
- x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
-
- x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
- x[:, :, torch.arange(0, x.size(2), scale), :] = 0
- x = x[:, :, 1:, 1:]
-
- for n in range(m.size(0)):
- for i in range(m.size(1)):
- for j in range(m.size(2)):
- if m[n, i, j] == 0:
- for k in range(2, scale - 2):
- for l in [0, 1]:
- x[n, :, i * scale + k, j * scale + k - l] = 0
- x[
- n, :, i * scale + scale - 1 - k, j * scale + k - l
- ] = 0
-
- return x
-
- def seq2str(self, seq):
- result = []
- for s in seq:
- result.append("".join([self.token2char[v] for v in s]))
- return result
-
- def save_image(
- self,
- result_dir,
- filename,
- prompts,
- answers,
- predicted_prompts=None,
- predicted_answers=None,
- ):
- if predicted_prompts is None:
- predicted_prompts = 255
-
- if predicted_answers is None:
- predicted_answers = 255
-
- def add_frame(x, c, margin, bottom=False):
- if bottom:
- h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
- else:
- h, w, di, dj = (
- x.size(2) + 2 * margin,
- x.size(3) + 2 * margin,
- margin,
- margin,
- )
-
- y = x.new_full((x.size(0), x.size(1), h, w), 0)
-
- if type(c) is int:
- y[...] = c
- else:
- c = c.long()[:, None]
- c = (
- (c == 1).long() * torch.tensor([0, 255, 0], device=c.device)
- + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device)
- + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device)
- )
- y[...] = c[:, :, None, None]
-
- y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
-
- return y
-
- margin = 4
-
- img_prompts = add_frame(self.frame2img(prompts.to("cpu")), c=0, margin=1)
- h = img_prompts.size(2)
- img_answers = add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1)
-
- img_prompts = add_frame(img_prompts, c=255, margin=margin, bottom=True)
- img_answers = add_frame(img_answers, c=255, margin=margin, bottom=True)
-
- img_prompts = add_frame(
- img_prompts, c=predicted_prompts, margin=margin, bottom=True
- )
- img_answers = add_frame(
- img_answers, c=predicted_answers, margin=margin, bottom=True
- )
-
- marker_size = 16
-
- separator = img_prompts.new_full(
- (
- img_prompts.size(0),
- img_prompts.size(1),
- img_prompts.size(2),
- marker_size,
- ),
- 255,
- )
-
- separator[:, :, 0] = 0
- separator[:, :, h - 1] = 0
-
- for k in range(1, 2 * marker_size - 8):
- i = k - (marker_size - 4)
- j = marker_size - 5 - abs(i)
- separator[:, :, h // 2 - 1 + i, 2 + j] = 0
- separator[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
-
- img = torch.cat([img_prompts, separator, img_answers], dim=3)
-
- image_name = os.path.join(result_dir, filename)
- torchvision.utils.save_image(
- img.float() / 255.0, image_name, nrow=6, padding=margin * 4, pad_value=1.0
- )
-
- ######################################################################
-
- def nb_token_values(self):
- return len(self.colors)
-
- def generate_prompts_and_answers(self, nb):
- frame_sequences = self.generate_frame_sequences(nb)
- frame_sequences = torch.cat([x[None] for x in frame_sequences], dim=0)
-
- prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1)
-
- answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1)
-
- # warnings.warn("dirty test with longer answer", RuntimeWarning)
- # answers = torch.cat(
- # [
- # frame_sequences[:, frame_sequences.size(1) // 2 :],
- # frame_sequences[:, frame_sequences.size(1) // 2 :],
- # ],
- # dim=3,
- # ).flatten(1)
-
- return prompts, answers
-
- def save_quizzes(
- self,
- result_dir,
- filename_prefix,
- prompts,
- answers,
- predicted_prompts=None,
- predicted_answers=None,
- ):
- self.save_image(
- result_dir,
- filename_prefix + ".png",
- prompts,
- answers,
- predicted_prompts,
- predicted_answers,
- )
-
-
-######################################################################
-
-if __name__ == "__main__":
- import time
-
- sky = Sky(height=6, width=8, speed=1, nb_iterations=4)
-
- prompts, answers = sky.generate_prompts_and_answers(4)
-
- predicted_prompts = torch.randint(3, (prompts.size(0),)) - 1
- predicted_answers = torch.randint(3, (prompts.size(0),)) - 1
-
- sky.save_quizzes(
- "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers
- )
-
- # start_time = time.perf_counter()
- # token_sequences = sky.generate_token_sequences(nb=64)
- # delay = time.perf_counter() - start_time
- # print(f"{token_sequences.size(0)/delay:02f} seq/s")
-
- # print(sky.seq2str(seq[:4]))
-
- # for t in range(len(it[0])):
- # img = torch.cat([sky.frame2img(f[t]) for f in it], dim=0)
- # torchvision.utils.save_image(
- # img.float() / 255.0,
- # f"/tmp/frame_{t:03d}.png",
- # nrow=8,
- # padding=6,
- # pad_value=0,
- # )
-
- # m = (torch.rand(seq.size()) < 0.05).long()
- # seq = (1 - m) * seq + m * 23
-
- # print(seq.size())
- # img = sky.seq2img(token_sequences)
- # print(img.size())
-
- # torchvision.utils.save_image(
- # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0
- # )
+++ /dev/null
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import math, os, tqdm, warnings
-
-import torch, torchvision
-
-from torch import nn
-from torch.nn import functional as F
-
-from mygpt import BracketedSequence
-
-######################################################################
-
-
-def masked_inplace_autoregression(
- model,
- batch_size,
- input,
- ar_mask,
- summed_logits,
- temperature,
- deterministic_synthesis,
- forbidden_tokens=None,
- logit_biases=None,
- progress_bar_desc="autoregression",
- device=torch.device("cpu"),
-):
- assert input.size() == ar_mask.size()
-
- batches = zip(input.split(batch_size), ar_mask.split(batch_size))
-
- if progress_bar_desc is not None:
- batches = tqdm.tqdm(
- batches,
- dynamic_ncols=True,
- desc=progress_bar_desc,
- total=(input.size(0) + batch_size - 1) // batch_size,
- )
-
- with torch.autograd.no_grad():
- t = model.training
- model.eval()
-
- for input, ar_mask in batches:
- model.masked_inplace_autoregression(
- input=input,
- ar_mask=ar_mask,
- summed_logits=summed_logits,
- temperature=temperature,
- deterministic_synthesis=deterministic_synthesis,
- forbidden_tokens=forbidden_tokens,
- forced_biases=logit_biases,
- )
-
- model.train(t)
-
-
-######################################################################
-
-
-class Task:
- def batches(self, split="train", nb_to_use=-1, desc=None):
- pass
-
- def vocabulary_size(self):
- pass
-
- def produce_results(
- self, n_epoch, model, result_dir, logger, deterministic_synthesis
- ):
- pass
-
-
-######################################################################
-
-import world
-
-
-class World(Task):
- def save_image(self, input, result_dir, filename, logger):
- img = world.seq2img(input.to("cpu"), self.height, self.width)
- image_name = os.path.join(result_dir, filename)
- torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
- logger(f"wrote {image_name}")
-
- def make_ar_mask(self, input):
- b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2
- return b.long()[None, :].expand_as(input)
-
- def __init__(
- self,
- nb_train_samples,
- nb_test_samples,
- batch_size,
- result_dir=None,
- logger=None,
- device=torch.device("cpu"),
- ):
- super().__init__()
-
- self.batch_size = batch_size
- self.device = device
- self.height = 6
- self.width = 8
-
- self.train_input = world.generate_seq(
- nb_train_samples, height=self.height, width=self.width
- ).to(device)
-
- self.test_input = world.generate_seq(
- nb_test_samples, height=self.height, width=self.width
- ).to(device)
-
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
-
- self.train_quizzes = []
- self.test_quizzes = []
-
- if result_dir is not None:
- self.save_image(
- self.train_input[:72], result_dir, f"world_train.png", logger
- )
-
- def batches(self, split="train", desc=None):
- assert split in {"train", "test"}
- if split == "train":
- input = self.train_input
- quizzes = self.train_quizzes
- else:
- input = self.test_input
- quizzes = self.test_quizzes
-
- if len(quizzes) > 0:
- quizzes = torch.cat(quizzes, dim=0)
- if quizzes.size(0) > input.size(0) // 2:
- i = torch.randperm(input.size(0))[: input.size(0) // 2]
- quizzes = quizzes[i]
-
- i = torch.randperm(input.size(0))[: input.size(0) - quizzes.size(0)]
- input = input[i]
-
- self.nb_batch_samples_world = input.size(0)
- self.nb_batch_samples_quizzes = quizzes.size(0)
-
- input = torch.cat([input, quizzes], dim=0)
- else:
- self.nb_batch_samples_world = input.size(0)
- self.nb_batch_samples_quizzes = 0
-
- # Shuffle
- input = input[torch.randperm(input.size(0))]
-
- if desc is None:
- desc = f"epoch-{split}"
- for batch in tqdm.tqdm(
- input.split(self.batch_size), dynamic_ncols=True, desc=desc
- ):
- yield batch
-
- def vocabulary_size(self):
- return self.nb_codes
-
- def produce_results(
- self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
- ):
- def compute_accuracy(input, logger=None):
- input = input[:nmax]
- ar_mask = self.make_ar_mask(input)
- result = input.clone() * (1 - ar_mask)
-
- masked_inplace_autoregression(
- model=model,
- batch_size=self.batch_size,
- input=result,
- ar_mask=ar_mask,
- summed_logits=None,
- temperature=1.0,
- deterministic_synthesis=deterministic_synthesis,
- progress_bar_desc=None,
- device=self.device,
- )
-
- nb_total, nb_correct = (
- input.size(0),
- (input == result).long().min(dim=1).values.sum(),
- )
-
- return nb_total, nb_correct
-
- train_nb_total, train_nb_correct = compute_accuracy(self.train_input)
-
- logger(
- f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
- )
-
- test_nb_total, test_nb_correct = compute_accuracy(self.test_input, logger)
-
- logger(
- f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
- )
-
- main_test_accuracy = test_nb_correct / test_nb_total
- logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}")
-
- ##############################
-
- input = self.test_input[:96]
- ar_mask = self.make_ar_mask(input)
- result = input.clone() * (1 - ar_mask)
-
- masked_inplace_autoregression(
- model=model,
- batch_size=self.batch_size,
- input=result,
- ar_mask=ar_mask,
- summed_logits=None,
- temperature=1.0,
- deterministic_synthesis=deterministic_synthesis,
- progress_bar_desc=None,
- device=self.device,
- )
-
- self.save_image(
- result[:72],
- result_dir,
- f"world_prediction_{n_epoch:04d}_{model.id:02d}.png",
- logger,
- )
-
- return main_test_accuracy
-
- def renew_samples(self, nb, for_train=True):
- input = self.train_input if for_train else self.test_input
- nb = min(nb, input.size(0))
- input[:-nb] = input[nb:].clone()
- input[-nb:] = world.generate_seq(nb, height=self.height, width=self.width).to(
- self.device
- )
-
- def store_new_quizzes(self, new_quizzes, for_train=True):
- if for_train:
- self.train_quizzes.append(new_quizzes)
- else:
- self.test_quizzes.append(new_quizzes)
-
- def create_new_quizzes(
- self,
- n_epoch,
- result_dir,
- logger,
- nb,
- model,
- other_models,
- desired_average_logits=None,
- ):
- ###############################################################
- # Generate quizzes with model
-
- quizzes = torch.empty(
- nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
- )
-
- ar_mask = torch.full(quizzes.size(), 1, device=self.device)
- summed_logits = torch.empty(nb, device=self.device)
-
- temperature = 1
- d_temperature = 1
-
- while True:
- summed_logits[...] = 0
-
- masked_inplace_autoregression(
- model=model,
- batch_size=self.batch_size,
- input=quizzes,
- ar_mask=ar_mask,
- summed_logits=summed_logits,
- temperature=temperature,
- deterministic_synthesis=False,
- progress_bar_desc="creating quizzes",
- device=self.device,
- )
-
- average_logits = summed_logits.mean()
-
- logger(f"{average_logits=} {desired_average_logits=}")
-
- if desired_average_logits is None:
- break
-
- # Oh man that's ugly
- if average_logits < desired_average_logits * 1.1:
- if d_temperature > 0:
- d_temperature *= -0.5
- temperature += d_temperature
- elif average_logits > desired_average_logits:
- if d_temperature < 0:
- d_temperature *= -0.5
- temperature += d_temperature
- else:
- break
-
- logger(f"changing temperature to {temperature}")
-
- ###############################################################
- # Create the reverse quizzes
-
- l = self.height * self.width
- direction = quizzes[:, l : l + 1]
- direction = world.token_forward * (
- direction == world.token_backward
- ) + world.token_backward * (direction == world.token_forward)
- reverse_quizzes = torch.cat(
- [quizzes[:, l + 1 :], direction, quizzes[:, :l]], dim=1
- )
-
- ar_mask = self.make_ar_mask(quizzes)
-
- ###############################################################
- # Check how many of the other models can solve them in both
- # directions
-
- nb_correct = []
-
- for m in other_models:
- result = quizzes.clone()
-
- masked_inplace_autoregression(
- model=m,
- batch_size=self.batch_size,
- input=result,
- ar_mask=ar_mask,
- summed_logits=None,
- temperature=1.0,
- deterministic_synthesis=True,
- progress_bar_desc="solving quizzes",
- device=self.device,
- )
-
- correct = (quizzes == result).long().min(dim=-1).values
-
- reverse_result = reverse_quizzes.clone()
-
- masked_inplace_autoregression(
- model=m,
- batch_size=self.batch_size,
- input=reverse_result,
- ar_mask=ar_mask,
- summed_logits=None,
- temperature=1.0,
- deterministic_synthesis=True,
- progress_bar_desc="solving reversed quizzes",
- device=self.device,
- )
-
- reverse_correct = (
- (reverse_quizzes == reverse_result).long().min(dim=-1).values
- )
-
- nb_correct.append((correct * reverse_correct)[None, :])
-
- nb_correct = torch.cat(nb_correct, dim=0)
-
- # filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat")
- # with open(filename, "w") as f:
- # for k in nb_correct:
- # f.write(f"{k}\n")
-
- return quizzes, nb_correct.sum(dim=0), summed_logits.mean()
+++ /dev/null
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import math, sys, tqdm, os
-
-import torch, torchvision
-
-from torch import nn
-from torch.nn import functional as F
-
-######################################################################
-
-import problem
-
-
-class Wireworld(problem.Problem):
- colors = torch.tensor(
- [
- [128, 128, 128],
- [128, 128, 255],
- [255, 0, 0],
- [255, 255, 0],
- ]
- )
-
- token_empty = 0
- token_head = 1
- token_tail = 2
- token_conductor = 3
- token_forward = 4
- token_backward = 5
-
- token2char = (
- "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
- )
-
- def __init__(
- self, height=6, width=8, nb_objects=2, nb_walls=2, speed=1, nb_iterations=4
- ):
- self.height = height
- self.width = width
- self.nb_objects = nb_objects
- self.nb_walls = nb_walls
- self.speed = speed
- self.nb_iterations = nb_iterations
-
- def direction_tokens(self):
- return self.token_forward, self.token_backward
-
- def generate_frame_sequences(self, nb):
- result = []
- N = 100
- for _ in tqdm.tqdm(
- range(0, nb + N, N), dynamic_ncols=True, desc="world generation"
- ):
- result.append(self.generate_frame_sequences_hard(100))
- return torch.cat(result, dim=0)[:nb]
-
- def generate_frame_sequences_hard(self, nb):
- frame_sequences = []
- nb_frames = (self.nb_iterations - 1) * self.speed + 1
-
- result = torch.full(
- (nb * 4, nb_frames, self.height, self.width),
- self.token_empty,
- )
-
- for n in range(result.size(0)):
- while True:
- i = torch.randint(self.height, (1,))
- j = torch.randint(self.width, (1,))
- v = torch.randint(2, (2,))
- vi = v[0] * (v[1] * 2 - 1)
- vj = (1 - v[0]) * (v[1] * 2 - 1)
- while True:
- if i < 0 or i >= self.height or j < 0 or j >= self.width:
- break
- o = 0
- if i > 0:
- o += (result[n, 0, i - 1, j] == self.token_conductor).long()
- if i < self.height - 1:
- o += (result[n, 0, i + 1, j] == self.token_conductor).long()
- if j > 0:
- o += (result[n, 0, i, j - 1] == self.token_conductor).long()
- if j < self.width - 1:
- o += (result[n, 0, i, j + 1] == self.token_conductor).long()
- if o > 1:
- break
- result[n, 0, i, j] = self.token_conductor
- i += vi
- j += vj
- if (
- result[n, 0] == self.token_conductor
- ).long().sum() > self.width and torch.rand(1) < 0.5:
- break
-
- while True:
- for _ in range(self.height * self.width):
- i = torch.randint(self.height, (1,))
- j = torch.randint(self.width, (1,))
- v = torch.randint(2, (2,))
- vi = v[0] * (v[1] * 2 - 1)
- vj = (1 - v[0]) * (v[1] * 2 - 1)
- if (
- i + vi >= 0
- and i + vi < self.height
- and j + vj >= 0
- and j + vj < self.width
- and result[n, 0, i, j] == self.token_conductor
- and result[n, 0, i + vi, j + vj] == self.token_conductor
- ):
- result[n, 0, i, j] = self.token_head
- result[n, 0, i + vi, j + vj] = self.token_tail
- break
-
- # if torch.rand(1) < 0.75:
- break
-
- weight = torch.full((1, 1, 3, 3), 1.0)
-
- mask = (torch.rand(result[:, 0].size()) < 0.01).long()
- rand = torch.randint(4, mask.size())
- result[:, 0] = mask * rand + (1 - mask) * result[:, 0]
-
- # empty->empty
- # head->tail
- # tail->conductor
- # conductor->head if 1 or 2 head in the neighborhood, or remains conductor
-
- nb_heads = (result[:, 0] == self.token_head).flatten(1).long().sum(dim=1)
- valid = nb_heads > 0
-
- for l in range(nb_frames - 1):
- nb_head_neighbors = (
- F.conv2d(
- input=(result[:, l] == self.token_head).float()[:, None, :, :],
- weight=weight,
- padding=1,
- )
- .long()
- .squeeze(1)
- )
- mask_1_or_2_heads = (nb_head_neighbors == 1).long() + (
- nb_head_neighbors == 2
- ).long()
- result[:, l + 1] = (
- (result[:, l] == self.token_empty).long() * self.token_empty
- + (result[:, l] == self.token_head).long() * self.token_tail
- + (result[:, l] == self.token_tail).long() * self.token_conductor
- + (result[:, l] == self.token_conductor).long()
- * (
- mask_1_or_2_heads * self.token_head
- + (1 - mask_1_or_2_heads) * self.token_conductor
- )
- )
- pred_nb_heads = nb_heads
- nb_heads = (
- (result[:, l + 1] == self.token_head).flatten(1).long().sum(dim=1)
- )
- valid = torch.logical_and(valid, (nb_heads >= pred_nb_heads))
-
- result = result[valid]
-
- result = result[
- :, torch.arange(self.nb_iterations, device=result.device) * self.speed
- ]
-
- i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0
- result = result[i]
-
- # print(f"{result.size(0)=} {nb=}")
-
- if result.size(0) < nb:
- # print(result.size(0))
- result = torch.cat(
- [result, self.generate_frame_sequences(nb - result.size(0))], dim=0
- )
-
- return result[:nb]
-
- def generate_token_sequences(self, nb):
- frame_sequences = self.generate_frame_sequences(nb)
-
- result = []
-
- for frame_sequence in frame_sequences:
- a = []
- if torch.rand(1) < 0.5:
- for frame in frame_sequence:
- if len(a) > 0:
- a.append(torch.tensor([self.token_forward]))
- a.append(frame.flatten())
- else:
- for frame in reversed(frame_sequence):
- if len(a) > 0:
- a.append(torch.tensor([self.token_backward]))
- a.append(frame.flatten())
-
- result.append(torch.cat(a, dim=0)[None, :])
-
- return torch.cat(result, dim=0)
-
- ######################################################################
-
- def frame2img(self, x, scale=15):
- x = x.reshape(-1, self.height, self.width)
- m = torch.logical_and(x >= 0, x < 4).long()
-
- x = self.colors[x * m].permute(0, 3, 1, 2)
- s = x.shape
- x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
- x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
-
- x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
- x[:, :, torch.arange(0, x.size(2), scale), :] = 0
- x = x[:, :, 1:, 1:]
-
- for n in range(m.size(0)):
- for i in range(m.size(1)):
- for j in range(m.size(2)):
- if m[n, i, j] == 0:
- for k in range(2, scale - 2):
- for l in [0, 1]:
- x[n, :, i * scale + k, j * scale + k - l] = 0
- x[
- n, :, i * scale + scale - 1 - k, j * scale + k - l
- ] = 0
-
- return x
-
- def seq2img(self, seq, scale=15):
- all = [
- self.frame2img(
- seq[:, : self.height * self.width].reshape(-1, self.height, self.width),
- scale,
- )
- ]
-
- separator = torch.full((seq.size(0), 3, self.height * scale - 1, 1), 0)
-
- t = self.height * self.width
-
- while t < seq.size(1):
- direction_tokens = seq[:, t]
- t += 1
-
- direction_images = self.colors[
- torch.full(
- (direction_tokens.size(0), self.height * scale - 1, scale), 0
- )
- ].permute(0, 3, 1, 2)
-
- for n in range(direction_tokens.size(0)):
- if direction_tokens[n] == self.token_forward:
- for k in range(scale):
- for l in [0, 1]:
- direction_images[
- n,
- :,
- (self.height * scale) // 2 - scale // 2 + k - l,
- 3 + scale // 2 - abs(k - scale // 2),
- ] = 0
- elif direction_tokens[n] == self.token_backward:
- for k in range(scale):
- for l in [0, 1]:
- direction_images[
- n,
- :,
- (self.height * scale) // 2 - scale // 2 + k - l,
- 3 + abs(k - scale // 2),
- ] = 0
- else:
- for k in range(2, scale - 2):
- for l in [0, 1]:
- direction_images[
- n,
- :,
- (self.height * scale) // 2 - scale // 2 + k - l,
- k,
- ] = 0
- direction_images[
- n,
- :,
- (self.height * scale) // 2 - scale // 2 + k - l,
- scale - 1 - k,
- ] = 0
-
- all += [
- separator,
- direction_images,
- separator,
- self.frame2img(
- seq[:, t : t + self.height * self.width].reshape(
- -1, self.height, self.width
- ),
- scale,
- ),
- ]
-
- t += self.height * self.width
-
- return torch.cat(all, dim=3)
-
- def seq2str(self, seq):
- result = []
- for s in seq:
- result.append("".join([self.token2char[v] for v in s]))
- return result
-
- def save_image(self, input, result_dir, filename):
- img = self.seq2img(input.to("cpu"))
- image_name = os.path.join(result_dir, filename)
- torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
-
- def save_quizzes(self, input, result_dir, filename_prefix):
- self.save_image(input, result_dir, filename_prefix + ".png")
-
-
-######################################################################
-
-if __name__ == "__main__":
- import time
-
- wireworld = Wireworld(height=8, width=10, nb_iterations=5, speed=1)
-
- start_time = time.perf_counter()
- frame_sequences = wireworld.generate_frame_sequences(nb=96)
- delay = time.perf_counter() - start_time
- print(f"{frame_sequences.size(0)/delay:02f} seq/s")
-
- # print(wireworld.seq2str(seq[:4]))
-
- for t in range(frame_sequences.size(1)):
- img = wireworld.seq2img(frame_sequences[:, t])
- torchvision.utils.save_image(
- img.float() / 255.0,
- f"/tmp/frame_{t:03d}.png",
- nrow=8,
- padding=6,
- pad_value=0,
- )
-
- # m = (torch.rand(seq.size()) < 0.05).long()
- # seq = (1 - m) * seq + m * 23
-
- wireworld = Wireworld(height=8, width=10, nb_iterations=2, speed=5)
- token_sequences = wireworld.generate_token_sequences(32)
- wireworld.save_quizzes(token_sequences, "/tmp", "seq")
- # img = wireworld.seq2img(frame_sequences[:60])
-
- # torchvision.utils.save_image(
- # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=10, pad_value=0.1
- # )