--- /dev/null
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+import math
+
+import torch
+
+from torch import nn
+from torch.nn import functional as F
+
+# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+
+######################################################################
+
+
+class VaswaniPositionalEncoding(nn.Module):
+ def __init__(self, len_max):
+ super().__init__()
+ self.len_max = len_max
+
+ def forward(self, x):
+ t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None]
+ j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :]
+ k = j % 2 # works with float, weird
+ pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k)
+ y = x + pe
+ return y
+
+
+######################################################################
+
+
+class WithResidual(nn.Module):
+ def __init__(self, *f):
+ super().__init__()
+ self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+
+ def forward(self, x):
+ return x + self.f(x)
+
+
+######################################################################
+
+
+def vanilla_attention(q, k, v):
+ a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
+ a = a.softmax(dim=3)
+ y = torch.einsum("nhts,nhsd->nhtd", a, v)
+ return y
+
+
+######################################################################
+
+
+class MHAttention(nn.Module):
+ def __init__(
+ self,
+ dim_model,
+ dim_qk,
+ dim_v,
+ nb_heads=1,
+ attention=vanilla_attention,
+ attention_dropout=0.0,
+ ):
+ super().__init__()
+
+ def randw(*d):
+ return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+
+ self.attention = attention
+ self.attention_dropout = attention_dropout
+ self.w_q = randw(nb_heads, dim_qk, dim_model)
+ self.w_k = randw(nb_heads, dim_qk, dim_model)
+ self.w_v = randw(nb_heads, dim_v, dim_model)
+ self.w_o = randw(nb_heads, dim_v, dim_model)
+
+ def forward(self, x_q, x_kv=None):
+ if x_kv is None:
+ x_kv = x_q
+
+ q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q)
+ k = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_k)
+ v = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_v)
+ y = self.attention(q, k, v)
+ y = torch.einsum("nhtd,hdc->ntc", y, self.w_o)
+
+ return y
+
+
+######################################################################
+
+
+class AttentionAE(nn.Module):
+ def __init__(
+ self,
+ vocabulary_size,
+ dim_model,
+ dim_keys,
+ dim_hidden,
+ nb_heads,
+ nb_blocks,
+ dropout=0.0,
+ len_max=1e5,
+ ):
+ super().__init__()
+
+ assert dim_model % nb_heads == 0
+
+ self.embedding = nn.Sequential(
+ nn.Embedding(2 * vocabulary_size, dim_model),
+ nn.Dropout(dropout),
+ )
+
+ self.positional_encoding = VaswaniPositionalEncoding(len_max)
+
+ trunk_blocks = []
+
+ for b in range(nb_blocks):
+ trunk_blocks += [
+ WithResidual(
+ nn.LayerNorm((dim_model,)),
+ MHAttention(
+ dim_model=dim_model,
+ dim_qk=dim_keys,
+ dim_v=dim_model // nb_heads,
+ nb_heads=nb_heads,
+ attention=vanilla_attention,
+ attention_dropout=dropout,
+ ),
+ ),
+ WithResidual(
+ nn.LayerNorm((dim_model,)),
+ nn.Linear(in_features=dim_model, out_features=dim_hidden),
+ nn.ReLU(),
+ nn.Linear(in_features=dim_hidden, out_features=dim_model),
+ nn.Dropout(dropout),
+ ),
+ ]
+
+ self.trunk = nn.Sequential(*trunk_blocks)
+
+ self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size)
+
+ with torch.no_grad():
+ for m in self.modules():
+ if isinstance(m, nn.Embedding):
+ m.weight.normal_(mean=0, std=2e-2)
+ elif isinstance(m, nn.LayerNorm):
+ m.bias.zero_()
+ m.weight.fill_(1.0)
+
+ def forward(self, x):
+ x = self.embedding(x)
+ x = self.positional_encoding(x)
+ x = self.trunk(x)
+ x = self.readout(x)
+ return x
+
+
+######################################################################
+
+
+class WithMaskedResidual(nn.Module):
+ def __init__(self, masker, *f):
+ super().__init__()
+ self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+ self.masker = masker
+ self.mask = None
+
+ def forward(self, x):
+ if self.mask is None:
+ self.mask = self.masker(x)
+ return self.mask * x + self.f(x)
+
+
+######################################################################
+
+
+class FunctionalAttentionAE(nn.Module):
+ def __init__(
+ self,
+ vocabulary_size,
+ dim_model,
+ dim_keys,
+ dim_hidden,
+ nb_heads,
+ nb_blocks,
+ nb_work_tokens=100,
+ dropout=0.0,
+ len_max=1e5,
+ ):
+ super().__init__()
+
+ assert dim_model % nb_heads == 0
+
+ self.nb_work_tokens = nb_work_tokens
+
+ self.embedding = nn.Sequential(
+ nn.Embedding(2 * vocabulary_size, dim_model),
+ nn.Dropout(dropout),
+ )
+
+ self.positional_encoding = VaswaniPositionalEncoding(len_max)
+
+ trunk_blocks = []
+
+ def no_peek_attention(q, k, v):
+ a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
+ n = self.nb_work_tokens
+ s = (q.size(2) - n) // 2
+ a[:, :, n + 1 * s : n + 2 * s, n + 0 * s : n + 1 * s] = float("-inf")
+ a[:, :, n + 0 * s : n + 1 * s, n + 1 * s : n + 2 * s] = float("-inf")
+ a = a.softmax(dim=3)
+ y = torch.einsum("nhts,nhsd->nhtd", a, v)
+ return y
+
+ def masker(x):
+ m = torch.arange(x.size(1), device=x.device) >= self.nb_work_tokens
+ return m[None, :, None]
+
+ for b in range(nb_blocks):
+ trunk_blocks += [
+ WithMaskedResidual(
+ masker,
+ nn.LayerNorm((dim_model,)),
+ MHAttention(
+ dim_model=dim_model,
+ dim_qk=dim_keys,
+ dim_v=dim_model // nb_heads,
+ nb_heads=nb_heads,
+ attention=no_peek_attention,
+ attention_dropout=dropout,
+ ),
+ ),
+ WithMaskedResidual(
+ masker,
+ nn.LayerNorm((dim_model,)),
+ nn.Linear(in_features=dim_model, out_features=dim_hidden),
+ nn.ReLU(),
+ nn.Linear(in_features=dim_hidden, out_features=dim_model),
+ nn.Dropout(dropout),
+ ),
+ ]
+
+ self.trunk = nn.Sequential(*trunk_blocks)
+
+ self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size)
+
+ with torch.no_grad():
+ for m in self.modules():
+ if isinstance(m, nn.Embedding):
+ m.weight.normal_(mean=0, std=2e-2)
+ elif isinstance(m, nn.LayerNorm):
+ m.bias.zero_()
+ m.weight.fill_(1.0)
+
+ def forward(self, x):
+ x = self.embedding(x)
+ x = F.pad(x, (0, 0, self.nb_work_tokens, 0))
+ x = self.positional_encoding(x)
+ x = self.trunk(x)
+ x = F.pad(x, (0, 0, -self.nb_work_tokens, 0))
+ x = self.readout(x)
+ return x
+
+
+######################################################################
+
+
+if __name__ == "__main__":
+ model = FunctionalAttentionAE(
+ vocabulary_size=100,
+ dim_model=16,
+ dim_keys=64,
+ dim_hidden=32,
+ nb_heads=4,
+ nb_work_tokens=10,
+ nb_blocks=4,
+ dropout=0.1,
+ )
+
+ x = torch.randint(100, (10, 50))
+ y = model(x)
+
+ with torch.no_grad():
+ model.eval()
+ x = torch.randint(100, (10, 50))
+ y = model(x)
+
+ print(y)
# Written by Francois Fleuret <francois@fleuret.org>
-import math, sys, tqdm, os, warnings
+import math, sys, tqdm, os, warnings, cairo, re
import torch, torchvision
######################################################################
+
+def text_img(height, width, text):
+ pixel_map = torch.full((height, width, 4), 255, dtype=torch.uint8)
+
+ surface = cairo.ImageSurface.create_for_data(
+ pixel_map.numpy(), cairo.FORMAT_ARGB32, pixel_map.size(1), pixel_map.size(0)
+ )
+
+ ctx = cairo.Context(surface)
+ ctx.set_source_rgb(0, 0, 0)
+ ctx.set_font_size(16)
+ ctx.select_font_face("courier", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL)
+ y = None
+ for line in text.split("\n"):
+ xbearing, ybearing, width, height, dx, dy = ctx.text_extents(line)
+ if y is None:
+ y = height * 1.5
+ x = height * 0.5
+
+ ctx.move_to(x, y)
+ ctx.show_text(line)
+ y += height * 1.5
+
+ ctx.stroke()
+
+ return pixel_map.permute(2, 0, 1)[None, :3].contiguous()
+
+
+######################################################################
+
import problem
class Grids(problem.Problem):
+ # grid_gray = 64
+ # thickness = 1
+ # background_gray = 255
+ # dots = False
+
+ grid_gray = 240
+ thickness = 0
+ background_gray = 240
+ dots = False
+
+ # grid_gray = 192
+ # thickness = 0
+ # background_gray = 255
+ # dots = True
+
named_colors = [
- ("white", [255, 255, 255]),
+ ("white", [background_gray, background_gray, background_gray]),
+ # ("white", [224, 224, 224]),
("red", [255, 0, 0]),
- ("green", [0, 192, 0]),
+ ("green", [0, 160, 0]),
("blue", [0, 0, 255]),
("yellow", [255, 224, 0]),
("cyan", [0, 255, 255]),
("violet", [224, 128, 255]),
- ("lightgreen", [192, 255, 192]),
+ ("lightgreen", [160, 255, 160]),
("brown", [165, 42, 42]),
("lightblue", [192, 192, 255]),
("gray", [128, 128, 128]),
]
+ def pure_noise(self, nb, device):
+ result = torch.randint(
+ self.nb_colors, (nb, 4 * (self.height * self.height)), device=device
+ )
+ return result
+
+ def trivial(self, quizzes):
+ S = self.height * self.width
+ assert self.check_order(quizzes, quad_order=("A", "f_A", "B", "f_B"))
+ a = quizzes.reshape(quizzes.size(0), 4, S + 1)[:, :, 1:]
+ return (a[:, 0] == a[:, 1]).min(dim=1).values | (a[:, 2] == a[:, 3]).min(
+ dim=1
+ ).values
+
+ def text2quiz(self, t):
+ chr2col = [
+ (".", "white"),
+ ("r", "red"),
+ ("g", "green"),
+ ("b", "blue"),
+ ("y", "yellow"),
+ ("c", "cyan"),
+ ("v", "violet"),
+ ("l", "lightgreen"),
+ ("o", "brown"),
+ ("l", "lightblue"),
+ ("a", "gray"),
+ ]
+
+ col2tok = dict([(c[0], n) for n, c in enumerate(self.named_colors)])
+ chr2tok = dict([(c, col2tok[col]) for c, col in chr2col])
+
+ t = re.sub(r"#.*\n", "", t).strip()
+ l = t.replace("\n\n", ";").split(";")
+
+ result = []
+
+ for t in l:
+ t = "".join(t.replace("\n", " ").strip().split(" "))
+ t = torch.tensor([chr2tok[c] for c in t])
+ t = t.reshape(10, 4, 10).permute(1, 0, 2).flatten(1)
+ t = torch.cat(
+ [
+ torch.tensor(
+ [
+ [self.token_A],
+ [self.token_f_A],
+ [self.token_B],
+ [self.token_f_B],
+ ]
+ ),
+ t,
+ ],
+ dim=1,
+ )
+ result.append(t.flatten()[None, :])
+
+ return torch.cat(result, dim=0)
+
+ def indices_select(self, quizzes, quad_order=("A", "f_A", "B", "f_B")):
+ S = self.height * self.width
+ q = quizzes.reshape(quizzes.size(0), 4, S + 1)
+ return (
+ (q[:, 0, 0] == self.l2tok[quad_order[0]])
+ & (q[:, 1, 0] == self.l2tok[quad_order[1]])
+ & (q[:, 2, 0] == self.l2tok[quad_order[2]])
+ & (q[:, 3, 0] == self.l2tok[quad_order[3]])
+ )
+
def __init__(
self,
max_nb_cached_chunks=None,
tasks=None,
):
self.colors = torch.tensor([c for _, c in self.named_colors])
+
+ self.nb_colors = len(self.colors)
+
+ self.nb_rec_max = 5
+ self.rfree = torch.tensor([])
+
self.height = 10
self.width = 10
+ self.seq_len = 4 * self.height * self.width
+
self.cache_rec_coo = {}
all_tasks = [
+ ############################################ fundamental ones
self.task_replace_color,
self.task_translate,
self.task_grow,
- self.task_half_fill,
self.task_frame,
+ ############################################
+ ############################################
+ self.task_half_fill,
self.task_detect,
- self.task_count,
- self.task_trajectory,
- self.task_bounce,
self.task_scale,
self.task_symbols,
+ self.task_corners,
+ self.task_contact,
+ self.task_path,
+ self.task_fill,
+ ############################################ hard ones
self.task_isometry,
- # self.task_islands,
+ self.task_trajectory,
+ self.task_bounce,
+ # self.task_count, # NOT REVERSIBLE
+ # self.task_islands, # TOO MESSY
]
if tasks is None:
######################################################################
- def frame2img(self, x, scale=15):
- x = x.reshape(x.size(0), self.height, -1)
- m = torch.logical_and(x >= 0, x < self.nb_token_values()).long()
- x = self.colors[x * m].permute(0, 3, 1, 2)
- s = x.shape
- x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
- x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
-
- x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
- x[:, :, torch.arange(0, x.size(2), scale), :] = 0
- x = x[:, :, 1:, 1:]
+ def vocabulary_size(self):
+ # warnings.warn("hack +4 to keep the vocabulary size unchanged", RuntimeWarning)
+ # return self.nb_colors+4
+ return self.nb_colors
+
+ def grid2img(self, x, scale=15, grids=True):
+ m = torch.logical_and(x >= 0, x < self.nb_colors).long()
+ y = self.colors[x * m].permute(0, 3, 1, 2)
+ s = y.shape
+ y = y[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
+ y = y.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
+
+ if grids:
+ for t in range(self.thickness):
+ y[:, :, :, torch.arange(t, y.size(3), scale)] = self.grid_gray
+ y[:, :, torch.arange(t, y.size(2), scale), :] = self.grid_gray
+ if self.dots:
+ z = y.reshape(
+ y.size(0),
+ y.size(1),
+ y.size(2) // scale,
+ scale,
+ y.size(3) // scale,
+ scale,
+ )
+ z = z[
+ :,
+ :,
+ :,
+ scale // 2 - 1 : scale // 2 + 2,
+ :,
+ scale // 2 - 1 : scale // 2 + 2,
+ ]
+ zz = (z == self.background_gray).min(dim=1, keepdim=True).values
+ z[...] = zz * self.grid_gray + (zz == False) * z
for n in range(m.size(0)):
for i in range(m.size(1)):
for j in range(m.size(2)):
- if m[n, i, j] == 0:
- for k in range(2, scale - 2):
- for l in [0, 1]:
- x[n, :, i * scale + k, j * scale + k - l] = 0
- x[
- n, :, i * scale + scale - 1 - k, j * scale + k - l
- ] = 0
+ if x[n, i, j] >= self.nb_colors:
+ # for k in range(3, scale - 2):
+ c = self.colors[x[n, i, j] - self.nb_colors][:, None, None]
+ # y[n, :, i * scale + k, j * scale + k] = c
+ # y[n, :, i * scale + k, j * scale + scale - k] = c
+ y[
+ n,
+ :,
+ i * scale + 3 : i * scale + scale - 2,
+ j * scale + 3 : j * scale + scale - 2,
+ ] = c
+
+ y = y[:, :, 1:, 1:]
+
+ return y
+
+ def add_frame(self, img, colors, thickness):
+ if thickness > 0:
+ result = img.new(
+ img.size(0),
+ img.size(1),
+ img.size(2) + 2 * thickness,
+ img.size(3) + 2 * thickness,
+ )
- return x
+ result[...] = colors[:, :, None, None]
+ result[:, :, thickness:-thickness, thickness:-thickness] = img
+ else:
+ result = img
- def save_image(
+ return result
+
+ def save_quizzes_as_image(
self,
result_dir,
filename,
- prompts,
- answers,
- predicted_prompts=None,
- predicted_answers=None,
+ quizzes,
+ predicted_parts=None,
+ correct_parts=None,
+ comments=None,
+ comment_height=48,
nrow=4,
- margin=8,
+ grids=True,
+ margin=12,
+ delta=False,
+ delta_highlight=False,
):
+ quizzes = quizzes.to("cpu")
+
S = self.height * self.width
- As = prompts[:, 0 * (S + 1) : 0 * (S + 1) + S].view(-1, self.height, self.width)
- f_As = prompts[:, 1 * (S + 1) : 1 * (S + 1) + S].view(
- -1, self.height, self.width
- )
- Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S].view(-1, self.height, self.width)
- prompts = torch.cat([As, f_As, Bs], dim=2)
- answers = answers.reshape(answers.size(0), self.height, self.width)
- if predicted_prompts is None:
- predicted_prompts = 255
+ A, f_A, B, f_B = (
+ quizzes.reshape(quizzes.size(0), 4, S)
+ .reshape(quizzes.size(0), 4, self.height, self.width)
+ .permute(1, 0, 2, 3)
+ )
- if predicted_answers is None:
- predicted_answers = 255
+ frame, white, gray, green, red = torch.tensor(
+ [
+ [self.grid_gray, self.grid_gray, self.grid_gray],
+ [255, 255, 255],
+ [200, 200, 200],
+ [0, 255, 0],
+ [255, 0, 0],
+ ],
+ device=quizzes.device,
+ )
- def add_frame(x, c, margin, bottom=False):
- if bottom:
- h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
- else:
- h, w, di, dj = (
- x.size(2) + 2 * margin,
- x.size(3) + 2 * margin,
- margin,
- margin,
- )
+ thickness = self.thickness
- y = x.new_full((x.size(0), x.size(1), h, w), 0)
+ if delta:
+ u = (A != f_A).long()
+ img_delta_A = self.add_frame(
+ self.grid2img(u, grids=grids), frame[None, :], thickness=thickness
+ )
+ img_delta_A = img_delta_A.min(dim=1, keepdim=True).values.expand_as(
+ img_delta_A
+ )
+ u = (B != f_B).long()
+ img_delta_B = self.add_frame(
+ self.grid2img(u, grids=grids), frame[None, :], thickness=thickness
+ )
+ img_delta_B = img_delta_B.min(dim=1, keepdim=True).values.expand_as(
+ img_delta_B
+ )
- if type(c) is int:
- y[...] = c
- else:
- c = c.long()[:, None]
- c = (
- (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long()))
- * torch.tensor([64, 64, 64])
- + (c == 1).long() * torch.tensor([0, 255, 0])
- + (c == 0).long() * torch.tensor([255, 255, 255])
- + (c == -1).long() * torch.tensor([255, 0, 0])
- )
- y[...] = c[:, :, None, None]
+ img_A = self.add_frame(
+ self.grid2img(A, grids=grids), frame[None, :], thickness=thickness
+ )
+ img_f_A = self.add_frame(
+ self.grid2img(f_A, grids=grids), frame[None, :], thickness=thickness
+ )
+ img_B = self.add_frame(
+ self.grid2img(B, grids=grids), frame[None, :], thickness=thickness
+ )
+ img_f_B = self.add_frame(
+ self.grid2img(f_B, grids=grids), frame[None, :], thickness=thickness
+ )
- y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
+ if delta_highlight:
+ q = (img_B == img_f_B).min(dim=1, keepdim=True).values.long()
+ img_f_B = q * (img_f_B // 4 + 192) + (1 - q) * img_f_B
- return y
+ # predicted_parts Nx4
+ # correct_parts Nx4
- img_prompts = torch.cat(
- [
- add_frame(
- add_frame(self.frame2img(x), c=0, margin=1),
- c=predicted_prompts,
- margin=margin,
+ if predicted_parts is None:
+ colors = white[None, None, :].expand(-1, 4, -1)
+ else:
+ predicted_parts = predicted_parts.to("cpu")
+ if correct_parts is None:
+ colors = (
+ predicted_parts[:, :, None] * gray[None, None, :]
+ + (1 - predicted_parts[:, :, None]) * white[None, None, :]
+ )
+ else:
+ correct_parts = correct_parts.to("cpu")
+ colors = (
+ predicted_parts[:, :, None]
+ * (
+ (correct_parts[:, :, None] == 1).long() * green[None, None, :]
+ + (correct_parts[:, :, None] == 0).long() * gray[None, None, :]
+ + (correct_parts[:, :, None] == -1).long() * red[None, None, :]
+ )
+ + (1 - predicted_parts[:, :, None]) * white[None, None, :]
)
- for x in prompts.to("cpu").split(split_size=self.width, dim=2)
- ],
- dim=3,
- )
- h = img_prompts.size(2)
- img_answers = add_frame(
- add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1),
- c=predicted_answers,
- margin=margin,
- )
+ separation = 6
- separator_size = 2 * margin
+ img_A = self.add_frame(img_A, colors[:, 0], thickness=separation)
+ img_f_A = self.add_frame(img_f_A, colors[:, 1], thickness=separation)
+ img_B = self.add_frame(img_B, colors[:, 2], thickness=separation)
+ img_f_B = self.add_frame(img_f_B, colors[:, 3], thickness=separation)
- separator = img_prompts.new_full(
- (
- img_prompts.size(0),
- img_prompts.size(1),
- img_prompts.size(2),
- separator_size,
- ),
- 255,
- )
+ img_A = self.add_frame(img_A, white[None, :], thickness=2)
+ img_f_A = self.add_frame(img_f_A, white[None, :], thickness=2)
+ img_B = self.add_frame(img_B, white[None, :], thickness=2)
+ img_f_B = self.add_frame(img_f_B, white[None, :], thickness=2)
- marker = img_prompts.new_full(
- (
- img_prompts.size(0),
- img_prompts.size(1),
- img_prompts.size(2),
- separator_size,
- ),
- 255,
- )
-
- # marker[:, :, 0] = 0
- # marker[:, :, h - 1] = 0
-
- for k in range(1, 2 * separator_size - 8):
- i = k - (separator_size - 4)
- j = separator_size - 5 - abs(i)
- marker[:, :, h // 2 - 1 + i, 2 + j] = 0
- marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
+ if delta:
+ img_delta_A = self.add_frame(
+ img_delta_A, colors[:, 0], thickness=separation
+ )
+ img_delta_A = self.add_frame(img_delta_A, white[None, :], thickness=2)
+ img_delta_B = self.add_frame(
+ img_delta_B, colors[:, 0], thickness=separation
+ )
+ img_delta_B = self.add_frame(img_delta_B, white[None, :], thickness=2)
+ img = torch.cat(
+ [img_A, img_f_A, img_delta_A, img_B, img_f_B, img_delta_B], dim=3
+ )
+ else:
+ img = torch.cat([img_A, img_f_A, img_B, img_f_B], dim=3)
- img = torch.cat(
- [
- img_prompts,
- marker,
- img_answers,
- ],
- dim=3,
- )
+ if comments is not None:
+ comment_img = [text_img(comment_height, img.size(3), t) for t in comments]
+ comment_img = torch.cat(comment_img, dim=0)
+ img = torch.cat([img, comment_img], dim=2)
image_name = os.path.join(result_dir, filename)
+
torchvision.utils.save_image(
img.float() / 255.0,
image_name,
######################################################################
- def nb_token_values(self):
- return len(self.colors)
-
# @torch.compile
def rec_coo(
self,
while True:
i = torch.randint(self.height, (N * nb_rec, 2)).sort(dim=-1).values
j = torch.randint(self.width, (N * nb_rec, 2)).sort(dim=-1).values
-
+ i[:, 1] += 1
+ j[:, 1] += 1
big_enough = (
(i[:, 1] >= i[:, 0] + min_height)
& (j[:, 1] >= j[:, 0] + min_height)
j[:, 1, 0],
j[:, 1, 1],
)
- no_overlap = torch.logical_not(
+ no_overlap = (
(A_i1 >= B_i2)
- & (A_i2 <= B_i1)
- & (A_j1 >= B_j1)
- & (A_j2 <= B_j1)
+ | (A_i2 <= B_i1)
+ | (A_j1 >= B_j2)
+ | (A_j2 <= B_j1)
)
- i, j = i[no_overlap], j[no_overlap]
+ i, j = (i[no_overlap], j[no_overlap])
elif nb_rec == 3:
A_i1, A_i2, A_j1, A_j2 = (
i[:, 0, 0],
######################################################################
+ def contact_matrices(self, rn, ri, rj, rz):
+ n = torch.arange(self.nb_rec_max)
+ return (
+ (
+ (
+ (
+ (ri[:, :, None, 0] == ri[:, None, :, 1] + 1)
+ | (ri[:, :, None, 1] + 1 == ri[:, None, :, 0])
+ )
+ & (rj[:, :, None, 0] <= rj[:, None, :, 1])
+ & (rj[:, :, None, 1] >= rj[:, None, :, 0])
+ )
+ | (
+ (
+ (rj[:, :, None, 0] == rj[:, None, :, 1] + 1)
+ | (rj[:, :, None, 1] + 1 == rj[:, None, :, 0])
+ )
+ & (ri[:, :, None, 0] <= ri[:, None, :, 1])
+ & (ri[:, :, None, 1] >= ri[:, None, :, 0])
+ )
+ )
+ # & (rz[:, :, None] == rz[:, None, :])
+ & (n[None, :, None] < rn[:, None, None])
+ & (n[None, None, :] < n[None, :, None])
+ )
+
+ def sample_rworld_states(self, N=1000):
+ while True:
+ ri = (
+ torch.randint(self.height - 2, (N, self.nb_rec_max, 2))
+ .sort(dim=2)
+ .values
+ )
+ ri[:, :, 1] += 2
+ rj = (
+ torch.randint(self.width - 2, (N, self.nb_rec_max, 2))
+ .sort(dim=2)
+ .values
+ )
+ rj[:, :, 1] += 2
+ rn = torch.randint(self.nb_rec_max - 1, (N,)) + 2
+ rz = torch.randint(2, (N, self.nb_rec_max))
+ rc = torch.randint(self.nb_colors - 1, (N, self.nb_rec_max)) + 1
+ n = torch.arange(self.nb_rec_max)
+ nb_collisions = (
+ (
+ (ri[:, :, None, 0] <= ri[:, None, :, 1])
+ & (ri[:, :, None, 1] >= ri[:, None, :, 0])
+ & (rj[:, :, None, 0] <= rj[:, None, :, 1])
+ & (rj[:, :, None, 1] >= rj[:, None, :, 0])
+ & (rz[:, :, None] == rz[:, None, :])
+ & (n[None, :, None] < rn[:, None, None])
+ & (n[None, None, :] < n[None, :, None])
+ )
+ .long()
+ .flatten(1)
+ .sum(dim=1)
+ )
+
+ no_collision = nb_collisions == 0
+
+ if no_collision.any():
+ print(no_collision.long().sum() / N)
+ self.rn = rn[no_collision]
+ self.ri = ri[no_collision]
+ self.rj = rj[no_collision]
+ self.rz = rz[no_collision]
+ self.rc = rc[no_collision]
+
+ nb_contact = (
+ self.contact_matrices(rn, ri, rj, rz).long().flatten(1).sum(dim=1)
+ )
+
+ self.rcontact = nb_contact > 0
+ self.rfree = torch.full((self.rn.size(0),), True)
+
+ break
+
+ def get_recworld_state(self):
+ if not self.rfree.any():
+ self.sample_rworld_states()
+ k = torch.arange(self.rn.size(0))[self.rfree]
+ k = k[torch.randint(k.size(0), (1,))].item()
+ self.rfree[k] = False
+ return self.rn[k], self.ri[k], self.rj[k], self.rz[k], self.rc[k]
+
+ def draw_state(self, X, rn, ri, rj, rz, rc):
+ for n in sorted(list(range(rn)), key=lambda n: rz[n].item()):
+ X[ri[n, 0] : ri[n, 1] + 1, rj[n, 0] : rj[n, 1] + 1] = rc[n]
+
+ def task_recworld_immobile(self, A, f_A, B, f_B):
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ rn, ri, rj, rz, rc = self.get_recworld_state()
+ self.draw_state(X, rn, ri, rj, rz, rc)
+ ri += 1
+ self.draw_state(f_X, rn, ri, rj, rz, rc)
+
+ ######################################################################
+
# @torch.compile
def task_replace_color(self, A, f_A, B, f_B):
nb_rec = 3
- c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+ c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
r = self.rec_coo(nb_rec, prevent_overlap=True)
for n in range(nb_rec):
X[i1:i2, j1:j2] = c[n]
f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
+ # @torch.compile
+ def task_symmetry(self, A, f_A, B, f_B):
+ a, b = torch.randint(2, (2,))
+ nb_rec = 3
+ c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ while True:
+ r = self.rec_coo(nb_rec, prevent_overlap=True)
+ if min([x[2] for x in r]) > self.height // 2 + 1:
+ break
+ for n in range(nb_rec):
+ i1, j1, i2, j2 = r[n]
+ X[i1:i2, j1:j2] = c[n]
+ f_X[i1:i2, j1:j2] = c[n]
+ X[: self.height // 2] = 0
+ f_X[: self.height // 2] = f_X.flip([0])[: self.height // 2]
+ if a == 1:
+ X[...] = X.flip((0,))
+ f_X[...] = f_X.flip((0,))
+ if b == 1:
+ X[...] = X.clone().t()
+ f_X[...] = f_X.clone().t()
+
# @torch.compile
def task_translate(self, A, f_A, B, f_B):
while True:
break
nb_rec = 3
- c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+ c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
while True:
r = self.rec_coo(nb_rec, prevent_overlap=True)
def task_grow(self, A, f_A, B, f_B):
di, dj = torch.randint(2, (2,)) * 2 - 1
nb_rec = 3
- c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+ c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
direction = torch.randint(2, (1,)).item()
for X, f_X in [(A, f_A), (B, f_B)]:
while True:
def task_half_fill(self, A, f_A, B, f_B):
di, dj = torch.randint(2, (2,)) * 2 - 1
nb_rec = 3
- c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1
+ c = torch.randperm(self.nb_colors - 1)[: 2 * nb_rec] + 1
direction = torch.randint(4, (1,)).item()
for X, f_X in [(A, f_A), (B, f_B)]:
r = self.rec_coo(nb_rec, prevent_overlap=True)
# @torch.compile
def task_frame(self, A, f_A, B, f_B):
nb_rec = 3
- c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+ c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
r = self.rec_coo(nb_rec, prevent_overlap=True)
for n in range(nb_rec):
# @torch.compile
def task_detect(self, A, f_A, B, f_B):
nb_rec = 3
- c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+ c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
r = self.rec_coo(nb_rec, prevent_overlap=True)
for n in range(nb_rec):
i1, j1, i2, j2 = r[n]
X[i1:i2, j1:j2] = c[n]
+ f_X[i1:i2, j1:j2] = c[n]
if n < nb_rec - 1:
- f_X[i1, j1] = c[-1]
+ for k in range(2):
+ f_X[i1 + k, j1] = c[-1]
+ f_X[i1, j1 + k] = c[-1]
# @torch.compile
def contact(self, X, i, j, q):
return no, nq, nq_diag
- def task_count(self, A, f_A, B, f_B):
+ def REMOVED_task_count(self, A, f_A, B, f_B):
while True:
error = False
- N = torch.randint(5, (1,)).item() + 1
- c = torch.zeros(N + 1)
- c[1:] = torch.randperm(len(self.colors) - 1)[:N] + 1
+ N = 3
+ c = torch.zeros(N + 2, dtype=torch.int64)
+ c[1:] = torch.randperm(self.nb_colors - 1)[: N + 1] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
if not hasattr(self, "cache_count") or len(self.cache_count) == 0:
self.height,
self.width,
nb_seeds=self.height * self.width // 8,
- nb_iterations=self.height * self.width // 10,
+ nb_iterations=self.height * self.width // 5,
)
)
X[...] = self.cache_count.pop()
- k = (X.max() + 1 + (c.size(0) - 1)).item()
- V = torch.arange(k) // (c.size(0) - 1)
- V = (V + torch.rand(V.size())).sort().indices[: X.max() + 1] % (
- c.size(0) - 1
- ) + 1
+ # k = (X.max() + 1 + (c.size(0) - 1)).item()
+ # V = torch.arange(k) // (c.size(0) - 1)
+ # V = (V + torch.rand(V.size())).sort().indices[: X.max() + 1] % (
+ # c.size(0) - 1
+ # ) + 1
+
+ V = torch.randint(N, (X.max() + 1,)) + 1
V[0] = 0
+ NB = F.one_hot(c[V]).sum(dim=0)
X[...] = c[V[X]]
-
- if F.one_hot(X.flatten()).max(dim=0).values.sum().item() == N + 1:
- f_X[...] = 0
- for e in range(1, N + 1):
- for j in range((X == c[e]).sum() + 1):
- if j < self.width:
- f_X[e - 1, j] = c[e]
- else:
- error = True
- break
+ f_X[...] = X
+
+ if F.one_hot(X.flatten()).max(dim=0).values.sum().item() >= 3:
+ m = NB[c[:-1]].max()
+ if (NB[c[:-1]] == m).long().sum() == 1:
+ for e in range(1, N + 1):
+ if NB[c[e]] == m:
+ a = (f_X == c[e]).long()
+ f_X[...] = (1 - a) * f_X + a * c[-1]
else:
error = True
break
if not error:
break
+ assert F.one_hot(A.flatten()).max(dim=0).values.sum() >= 3
+
# @torch.compile
def task_trajectory(self, A, f_A, B, f_B):
- c = torch.randperm(len(self.colors) - 1)[:2] + 1
+ c = torch.randperm(self.nb_colors - 1)[:2] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
while True:
di, dj = torch.randint(7, (2,)) - 3
# @torch.compile
def task_bounce(self, A, f_A, B, f_B):
- c = torch.randperm(len(self.colors) - 1)[:3] + 1
+ c = torch.randperm(self.nb_colors - 1)[:3] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
# @torch.compile
def free(i, j):
f_X[i, j] = c[2]
if l <= 1:
X[i, j] = c[2]
+ f_X[i, j] = c[1]
if l >= self.width:
break
# @torch.compile
def task_scale(self, A, f_A, B, f_B):
- c = torch.randperm(len(self.colors) - 1)[:2] + 1
+ c = torch.randperm(self.nb_colors - 1)[:2] + 1
i, j = (
torch.randint(self.height // 2, (1,)).item(),
X[i + i1 : i + i2, j + j1 : j + j2] = c[0]
f_X[2 * i1 : 2 * i2, 2 * j1 : 2 * j2] = c[0]
- X[i, j] = c[1]
- f_X[0:2, 0:2] = c[1]
+ for k in range(2):
+ X[i + k, j] = c[1]
+ X[i, j + k] = c[1]
+ f_X[i + k, j] = c[1]
+ f_X[i, j + k] = c[1]
# @torch.compile
def task_symbols(self, A, f_A, B, f_B):
nb_rec = 4
- c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+ c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
delta = 3
for X, f_X in [(A, f_A), (B, f_B)]:
while True:
if d.min() > delta:
break
- for k in range(1, nb_rec):
- X[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k]
-
ai, aj = i.float().mean(), j.float().mean()
q = torch.randint(3, (1,)).item() + 1
- X[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0]
- X[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0]
- X[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0]
- X[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0]
-
assert i[q] != ai and j[q] != aj
- X[
+ for Z in [X, f_X]:
+ for k in range(0, nb_rec):
+ Z[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k]
+ # Z[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0]
+ # Z[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0]
+ # Z[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0]
+ # Z[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0]
+
+ # f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
+
+ f_X[i[0] + delta // 2, j[0] + delta // 2] = c[q]
+ # f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
+
+ ii, jj = (
i[0] + delta // 2 + (i[q] - ai).sign().long(),
j[0] + delta // 2 + (j[q] - aj).sign().long(),
- ] = c[nb_rec]
+ )
+
+ X[ii, jj] = c[nb_rec]
+ X[i[0] + delta // 2, jj] = c[nb_rec]
+ X[ii, j[0] + delta // 2] = c[nb_rec]
- f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
+ f_X[ii, jj] = c[nb_rec]
+ f_X[i[0] + delta // 2, jj] = c[nb_rec]
+ f_X[ii, j[0] + delta // 2] = c[nb_rec]
# @torch.compile
def task_isometry(self, A, f_A, B, f_B):
X[...] = 0
f_X[...] = 0
- c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+ c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
for r in range(nb_rec):
while True:
return dist * (1 - walls)
# @torch.compile
- def task_distance(self, A, f_A, B, f_B):
- c = torch.randperm(len(self.colors) - 1)[:3] + 1
+ def REMOVED_task_distance(self, A, f_A, B, f_B):
+ c = torch.randperm(self.nb_colors - 1)[:3] + 1
dist0 = torch.empty(self.height + 2, self.width + 2)
dist1 = torch.empty(self.height + 2, self.width + 2)
for X, f_X in [(A, f_A), (B, f_B)]:
# if
# @torch.compile
- def task_puzzle(self, A, f_A, B, f_B):
+ def TOO_HARD_task_puzzle(self, A, f_A, B, f_B):
S = 4
i0, j0 = (self.height - S) // 2, (self.width - S) // 2
- c = torch.randperm(len(self.colors) - 1)[:4] + 1
+ c = torch.randperm(self.nb_colors - 1)[:4] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
while True:
f_X[...] = 0
if f_X[i + i0, j + j0] == c[d]:
X[ii + i, jj + j] = c[d]
- def task_islands(self, A, f_A, B, f_B):
- c = torch.randperm(len(self.colors) - 1)[:2] + 1
+ def TOO_MESSY_task_islands(self, A, f_A, B, f_B):
+ c = torch.randperm(self.nb_colors - 1)[:2] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
if not hasattr(self, "cache_islands") or len(self.cache_islands) == 0:
self.cache_islands = list(
break
X[...] = (A > 0) * c[0]
- X[i, j] = c[1]
f_X[...] = (A == A[i, j]) * c[1] + ((A > 0) & (A != A[i, j])) * c[0]
+ f_X[i, j] = X[i, j]
+ X[i, j] = c[1]
+
+ # @torch.compile
+ def TOO_HARD_task_stack(self, A, f_A, B, f_B):
+ N = 5
+ c = torch.randperm(self.nb_colors - 1)[:N] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ i1, j1, i2, j2 = (
+ self.height // 2 - 1,
+ self.width // 2 - 1,
+ self.height // 2 + 1,
+ self.width // 2 + 1,
+ )
+ op = torch.tensor((0, 1, 2, 3) * 4)
+ op = op[torch.randperm(op.size(0))[:9]]
+ for q in range(op.size(0)):
+ u = 3 * (q // 3)
+ v = 3 * (q % 3)
+ d = c[torch.randint(N, (1,)).item()]
+ # X[u+1,v+1]=d
+ if op[q] == 0: # right
+ X[u : u + 3, v + 2] = d
+ elif op[q] == 1: # let
+ X[u : u + 3, v] = d
+ elif op[q] == 2: # bottom
+ X[u + 2, v : v + 3] = d
+ elif op[q] == 3: # top
+ X[u, v : v + 3] = d
+
+ if q == 0:
+ f_X[i1:i2, j1:j2] = d
+ elif op[q] == 0: # right
+ f_X[i1:i2, j2] = d
+ j2 += 1
+ elif op[q] == 1: # let
+ j1 -= 1
+ f_X[i1:i2, j1] = d
+ elif op[q] == 2: # bottom
+ f_X[i2, j1:j2] = d
+ i2 += 1
+ elif op[q] == 3: # top
+ i1 -= 1
+ f_X[i1, j1:j2] = d
+
+ def randint(self, *m):
+ m = torch.tensor(m)
+ return (torch.rand(m.size()) * m).long()
+
+ def TOO_HARD_task_matrices(self, A, f_A, B, f_B):
+ N = 6
+ c = torch.randperm(self.nb_colors - 1)[:N] + 1
+
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ M1 = torch.randint(2, (5, 5))
+ M2 = torch.randint(2, (5, 5))
+ P = M1 @ M2
+ for i in range(5):
+ for j in range(5):
+ X[i, j] = c[M1[i, j]]
+ X[i, j + 5] = c[M2[i, j]]
+ f_X[i, j] = c[M1[i, j]]
+ f_X[i, j + 5] = c[M2[i, j]]
+ f_X[i + 5, j + 5] = c[P[i, j]]
+
+ def TOO_HARD_task_compute(self, A, f_A, B, f_B):
+ N = 6
+ c = torch.randperm(self.nb_colors - 1)[:N] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ v = torch.randint((self.width - 1) // 2, (N,)) + 1
+ chain = torch.randperm(N)
+ eq = []
+ for i in range(chain.size(0) - 1):
+ i1, i2 = chain[i], chain[i + 1]
+ v1, v2 = v[i1], v[i2]
+ k = torch.arange(self.width // 2) + 1
+ d = ((k[None, :] * v1 - k[:, None] * v2) == 0).nonzero() + 1
+ d = d[torch.randint(d.size(0), (1,)).item()]
+ w1, w2 = d
+ eq.append((c[i1], w1, c[i2], w2))
+
+ ii = torch.randperm(self.height - 2)[: len(eq)]
+
+ for k, x in enumerate(eq):
+ i = ii[k]
+ c1, w1, c2, w2 = x
+ s = torch.randint(self.width - (w1 + w2) + 1, (1,)).item()
+ X[i, s : s + w1] = c1
+ X[i, s + w1 : s + w1 + w2] = c2
+ f_X[i, s : s + w1] = c1
+ f_X[i, s + w1 : s + w1 + w2] = c2
+
+ i1, i2 = torch.randperm(N)[:2]
+ v1, v2 = v[i1], v[i2]
+ k = torch.arange(self.width // 2) + 1
+ d = ((k[None, :] * v1 - k[:, None] * v2) == 0).nonzero() + 1
+ d = d[torch.randint(d.size(0), (1,)).item()]
+ w1, w2 = d
+ c1, c2 = c[i1], c[i2]
+ s = 0 # torch.randint(self.width - (w1 + w2) + 1, (1,)).item()
+ i = self.height - 1
+ X[i, s : s + w1] = c1
+ X[i, s + w1 : s + w1 + 1] = c2
+ f_X[i, s : s + w1] = c1
+ f_X[i, s + w1 : s + w1 + w2] = c2
+
+ # @torch.compile
+ # [ai1,ai2] [bi1,bi2]
+ def task_contact(self, A, f_A, B, f_B):
+ def rec_dist(a, b):
+ ai1, aj1, ai2, aj2 = a
+ bi1, bj1, bi2, bj2 = b
+ v = max(ai1 - bi2, bi1 - ai2)
+ h = max(aj1 - bj2, bj1 - aj2)
+ return min(max(v, 0) + max(h + 1, 0), max(v + 1, 0) + max(h, 0))
+
+ nb_rec = 3
+ c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ while True:
+ r = self.rec_coo(nb_rec, prevent_overlap=True)
+ d = [rec_dist(r[0], r[k]) for k in range(nb_rec)]
+ if min(d[1:]) == 0:
+ break
+
+ for n in range(nb_rec):
+ i1, j1, i2, j2 = r[n]
+ X[i1:i2, j1:j2] = c[n]
+ f_X[i1:i2, j1:j2] = c[n]
+ if d[n] == 0:
+ f_X[i1, j1:j2] = c[0]
+ f_X[i2 - 1, j1:j2] = c[0]
+ f_X[i1:i2, j1] = c[0]
+ f_X[i1:i2, j2 - 1] = c[0]
+
+ # @torch.compile
+ # [ai1,ai2] [bi1,bi2]
+ def task_corners(self, A, f_A, B, f_B):
+ polarity = torch.randint(2, (1,)).item()
+ nb_rec = 3
+ c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ r = self.rec_coo(nb_rec, prevent_overlap=True)
+
+ for n in range(nb_rec):
+ i1, j1, i2, j2 = r[n]
+ for k in range(2):
+ if polarity == 0:
+ X[i1 + k, j1] = c[n]
+ X[i2 - 1 - k, j2 - 1] = c[n]
+ X[i1, j1 + k] = c[n]
+ X[i2 - 1, j2 - 1 - k] = c[n]
+ else:
+ X[i1 + k, j2 - 1] = c[n]
+ X[i2 - 1 - k, j1] = c[n]
+ X[i1, j2 - 1 - k] = c[n]
+ X[i2 - 1, j1 + k] = c[n]
+ f_X[i1:i2, j1:j2] = c[n]
+
+ def compdist(self, X, i, j):
+ dd = X.new_full((self.height + 2, self.width + 2), self.height * self.width)
+ d = dd[1:-1, 1:-1]
+ m = (X > 0).long()
+ d[i, j] = 0
+ e = d.clone()
+ while True:
+ e[...] = d
+ d[...] = (
+ d.min(dd[:-2, 1:-1] + 1)
+ .min(dd[2:, 1:-1] + 1)
+ .min(dd[1:-1, :-2] + 1)
+ .min(dd[1:-1, 2:] + 1)
+ )
+ d[...] = (1 - m) * d + m * self.height * self.width
+ if e.equal(d):
+ break
+
+ return d
+
+ # @torch.compile
+ def task_path(self, A, f_A, B, f_B):
+ nb_rec = 2
+ c = torch.randperm(self.nb_colors - 1)[: nb_rec + 2] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ while True:
+ X[...] = 0
+ f_X[...] = 0
+
+ r = self.rec_coo(nb_rec, prevent_overlap=True)
+ for n in range(nb_rec):
+ i1, j1, i2, j2 = r[n]
+ X[i1:i2, j1:j2] = c[n]
+ f_X[i1:i2, j1:j2] = c[n]
+
+ i1, i2 = torch.randint(self.height, (2,))
+ j1, j2 = torch.randint(self.width, (2,))
+ if (
+ abs(i1 - i2) + abs(j1 - j2) > 2
+ and X[i1, j1] == 0
+ and X[i2, j2] == 0
+ ):
+ d2 = self.compdist(X, i2, j2)
+ d = self.compdist(X, i1, j1)
+
+ if d2[i1, j1] < 2 * self.width:
+ break
+
+ m = ((d + d2) == d[i2, j2]).long()
+ f_X[...] = m * c[-1] + (1 - m) * f_X
+
+ X[i1, j1] = c[-2]
+ X[i2, j2] = c[-2]
+ f_X[i1, j1] = c[-2]
+ f_X[i2, j2] = c[-2]
+
+ # @torch.compile
+ def task_fill(self, A, f_A, B, f_B):
+ nb_rec = 3
+ c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ accept_full = torch.rand(1) < 0.5
+
+ while True:
+ X[...] = 0
+ f_X[...] = 0
+
+ r = self.rec_coo(nb_rec, prevent_overlap=True)
+ for n in range(nb_rec):
+ i1, j1, i2, j2 = r[n]
+ X[i1:i2, j1:j2] = c[n]
+ f_X[i1:i2, j1:j2] = c[n]
+
+ while True:
+ i, j = (
+ torch.randint(self.height, (1,)).item(),
+ torch.randint(self.width, (1,)).item(),
+ )
+ if X[i, j] == 0:
+ break
+
+ d = self.compdist(X, i, j)
+ m = (d < self.height * self.width).long()
+ X[i, j] = c[-1]
+ f_X[...] = m * c[-1] + (1 - m) * f_X
+ f_X[i, j] = 0
+
+ if accept_full or (d * (X == 0)).max() == self.height * self.width:
+ break
+
+ def TOO_HARD_task_addition(self, A, f_A, B, f_B):
+ c = torch.randperm(self.nb_colors - 1)[:4] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ N1 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item()
+ N2 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item()
+ S = N1 + N2
+ for j in range(self.width):
+ r1 = (N1 // (2**j)) % 2
+ X[0, -j - 1] = c[r1]
+ f_X[0, -j - 1] = c[r1]
+ r2 = (N2 // (2**j)) % 2
+ X[1, -j - 1] = c[r2]
+ f_X[1, -j - 1] = c[r2]
+ rs = (S // (2**j)) % 2
+ f_X[2, -j - 1] = c[2 + rs]
+
+ def task_science_implicit(self, A, f_A, B, f_B):
+ nb_rec = 5
+ c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
+
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ while True:
+ i1, i2 = torch.randint(self.height, (2,)).sort().values
+ if i1 >= 1 and i2 < self.height and i1 + 3 < i2:
+ break
+
+ while True:
+ j1, j2 = torch.randint(self.width, (2,)).sort().values
+ if j1 >= 1 and j2 < self.width and j1 + 3 < j2:
+ break
+
+ f_X[i1:i2, j1:j2] = c[0]
+
+ # ---------------------
+
+ while True:
+ ii1, ii2 = torch.randint(self.height, (2,)).sort().values
+ if ii1 >= i1 and ii2 <= i2 and ii1 + 1 < ii2:
+ break
+ jj = torch.randint(j1, (1,))
+ X[ii1:ii2, jj:j1] = c[1]
+ f_X[ii1:ii2, jj:j1] = c[1]
+
+ while True:
+ ii1, ii2 = torch.randint(self.height, (2,)).sort().values
+ if ii1 >= i1 and ii2 <= i2 and ii1 + 1 < ii2:
+ break
+ jj = torch.randint(self.width - j2, (1,)) + j2 + 1
+ X[ii1:ii2, j2:jj] = c[2]
+ f_X[ii1:ii2, j2:jj] = c[2]
+
+ # ---------------------
+
+ while True:
+ jj1, jj2 = torch.randint(self.width, (2,)).sort().values
+ if jj1 >= j1 and jj2 <= j2 and jj1 + 1 < jj2:
+ break
+ ii = torch.randint(i1, (1,))
+ X[ii:i1, jj1:jj2] = c[3]
+ f_X[ii:i1, jj1:jj2] = c[3]
+
+ while True:
+ jj1, jj2 = torch.randint(self.width, (2,)).sort().values
+ if jj1 >= j1 and jj2 <= j2 and jj1 + 1 < jj2:
+ break
+ ii = torch.randint(self.height - i2, (1,)) + i2 + 1
+ X[i2:ii, jj1:jj2] = c[4]
+ f_X[i2:ii, jj1:jj2] = c[4]
+
+ def task_science_dot(self, A, f_A, B, f_B):
+ nb_rec = 3
+ c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ while True:
+ X[...] = 0
+ f_X[...] = 0
+ r = self.rec_coo(nb_rec, prevent_overlap=True)
+ i, j = (
+ torch.randint(self.height, (1,)).item(),
+ torch.randint(self.width, (1,)).item(),
+ )
+ q = 0
+ for n in range(nb_rec):
+ i1, j1, i2, j2 = r[n]
+ X[i1:i2, j1:j2] = c[n]
+ f_X[i1:i2, j1:j2] = c[n]
+ if i >= i1 and i < i2:
+ q += 1
+ f_X[i, j1:j2] = c[-1]
+ if j >= j1 and j < j2:
+ q += 1
+ f_X[i1:i2, j] = c[-1]
+ X[i, j] = c[-1]
+ f_X[i, j] = c[-1]
+ if q >= 2:
+ break
+
+ def collide(self, s, r, rs):
+ i, j = r
+ for i2, j2 in rs:
+ if abs(i - i2) < s and abs(j - j2) < s:
+ return True
+ return False
+
+ def task_science_tag(self, A, f_A, B, f_B):
+ c = torch.randperm(self.nb_colors - 1)[:4] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ rs = []
+ while len(rs) < 4:
+ i, j = (
+ torch.randint(self.height - 3, (1,)).item(),
+ torch.randint(self.width - 3, (1,)).item(),
+ )
+ if not self.collide(s=3, r=(i, j), rs=rs):
+ rs.append((i, j))
+
+ for k in range(len(rs)):
+ i, j = rs[k]
+ q = min(k, 2)
+ X[i, j : j + 3] = c[q]
+ X[i + 2, j : j + 3] = c[q]
+ X[i : i + 3, j] = c[q]
+ X[i : i + 3, j + 2] = c[q]
+
+ f_X[i, j : j + 3] = c[q]
+ f_X[i + 2, j : j + 3] = c[q]
+ f_X[i : i + 3, j] = c[q]
+ f_X[i : i + 3, j + 2] = c[q]
+ if q == 2:
+ f_X[i + 1, j + 1] = c[-1]
+
+ # end_tasks
######################################################################
- def trivial_prompts_and_answers(self, prompts, answers):
+ def create_empty_quizzes(self, nb, quad_order=("A", "f_A", "B", "f_B")):
S = self.height * self.width
- Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S]
- f_Bs = answers
- return (Bs == f_Bs).long().min(dim=-1).values > 0
+ quizzes = torch.zeros(nb, 4 * (S + 1), dtype=torch.int64)
+ quizzes[:, 0 * (S + 1)] = self.l2tok[quad_order[0]]
+ quizzes[:, 1 * (S + 1)] = self.l2tok[quad_order[1]]
+ quizzes[:, 2 * (S + 1)] = self.l2tok[quad_order[2]]
+ quizzes[:, 3 * (S + 1)] = self.l2tok[quad_order[3]]
- def generate_prompts_and_answers_(self, nb, tasks=None, progress_bar=False):
- if tasks is None:
- tasks = self.all_tasks
+ return quizzes
+ def generate_w_quizzes_(self, nb, tasks=None, progress_bar=False):
S = self.height * self.width
- prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64)
- answers = torch.zeros(nb, S, dtype=torch.int64)
- bunch = zip(prompts, answers)
+ if tasks is None:
+ tasks = self.all_tasks
+
+ quizzes = torch.empty(nb, 4 * self.height * self.width, dtype=torch.int64)
if progress_bar:
- bunch = tqdm.tqdm(
- bunch,
+ quizzes = tqdm.tqdm(
+ quizzes,
dynamic_ncols=True,
- desc="world generation",
- total=prompts.size(0),
+ desc="world quizzes generation",
+ total=quizzes.size(0),
)
- for prompt, answer in bunch:
- A = prompt[0 * (S + 1) : 0 * (S + 1) + S].view(self.height, self.width)
- f_A = prompt[1 * (S + 1) : 1 * (S + 1) + S].view(self.height, self.width)
- B = prompt[2 * (S + 1) : 2 * (S + 1) + S].view(self.height, self.width)
- f_B = answer.view(self.height, self.width)
+ for quiz in quizzes:
+ q = quiz.reshape(4, self.height, self.width)
+ q[...] = 0
+ A, f_A, B, f_B = q
task = tasks[torch.randint(len(tasks), (1,)).item()]
task(A, f_A, B, f_B)
- return prompts.flatten(1), answers.flatten(1)
+ return quizzes
- def save_quiz_illustrations(
- self,
- result_dir,
- filename_prefix,
- prompts,
- answers,
- predicted_prompts=None,
- predicted_answers=None,
- nrow=4,
- ):
- self.save_image(
- result_dir,
- filename_prefix + ".png",
- prompts,
- answers,
- predicted_prompts,
- predicted_answers,
- nrow,
- )
-
- def save_some_examples(self, result_dir):
- nb, nrow = 72, 4
+ def save_some_examples(self, result_dir, prefix=""):
+ nb, nrow = 256, 8
for t in self.all_tasks:
print(t.__name__)
- prompts, answers = self.generate_prompts_and_answers_(nb, tasks=[t])
- self.save_quiz_illustrations(
- result_dir, t.__name__, prompts[:nb], answers[:nb], nrow=nrow
+ quizzes = self.generate_w_quizzes_(nb, tasks=[t])
+ self.save_quizzes_as_image(
+ result_dir, prefix + t.__name__ + ".png", quizzes, nrow=nrow
)
import time
# grids = Grids(max_nb_cached_chunks=5, chunk_size=100, nb_threads=4)
+
grids = Grids()
- # nb = 1000
- # grids = problem.MultiThreadProblem(
- # grids, max_nb_cached_chunks=50, chunk_size=100, nb_threads=1
- # )
- # time.sleep(10)
- # start_time = time.perf_counter()
- # prompts, answers = grids.generate_prompts_and_answers(nb)
- # delay = time.perf_counter() - start_time
- # print(f"{prompts.size(0)/delay:02f} seq/s")
- # exit(0)
-
- # if True:
- nb, nrow = 72, 4
+ nb, nrow = 64, 4
# nb, nrow = 8, 2
# for t in grids.all_tasks:
- for t in [grids.task_distance]:
+
+ for t in [
+ grids.task_replace_color,
+ grids.task_translate,
+ grids.task_grow,
+ grids.task_frame,
+ ]:
print(t.__name__)
- prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
- grids.save_quiz_illustrations(
- "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow
+ w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
+
+ # w_quizzes[:5] = torch.randint(grids.vocabulary_size(), w_quizzes[:5].size())
+
+ grids.save_quizzes_as_image(
+ "/tmp",
+ t.__name__ + ".png",
+ w_quizzes,
+ delta=True,
+ # grids=False
+ # comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))],
)
- # exit(0)
+ exit(0)
+
+ q = grids.text2quiz(
+ """
+
+# the original
+
+vvvvaaaaa. rrrraaaaa. .......... ..........
+vvvvaaaaa. rrrraaaaa. ...aaa.... ...aaa....
+vvvvaaaaa. rrrraaaaa. ...aaa.... ...aaa....
+vvvvaaaaa. rrrraaaaa. ...aaa.... ...aaa....
+....aaaaa. ....aaaaa. .vvvvv.... .rrrrr....
+.......... .......... .vvvvvvvvv .rrrrroooo
+.......... .......... .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+
+vvvvaaaaa. rrrraaaaa. .......... ..........
+vvvvaaaaa. rrrraaaaa. .......... ..........
+vvvvaaaaa. rrrraaaaa. .......aaa .......aaa
+vvvvaaaaa. rrrraaaaa. .......aaa .......aaa
+....aaaaa. ....aaaaa. .vvvvv.aaa .rrrrr.aaa
+.......... .......... .vvvvvvvvv .rrrrroooo
+.......... .......... .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+
+#
+# so what
+#
+
+vvvv...... rrrr...... .......... ..........
+vvvv...... rrrr...... .......... ..........
+vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa
+vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa
+.....aaaaa .....aaaaa .vvvvv.aaa .rrrrr.aaa
+.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo
+.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+
+vvvv...... rrrr...... .......... ..........
+vvvv...... rrrr...... .......... ..........
+vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa
+vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa
+.....aaaaa .....aaaaa .vvvvv.aaa .rrrrr.aaa
+.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo
+.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+"""
+ )
+
+ grids.save_quizzes_as_image("/tmp", "test.png", q, nrow=1, grids=False)
+
+ exit(0)
nb = 1000
- # for t in grids.all_tasks:
- for t in [grids.task_distance]:
+ for t in [
+ # grids.task_bounce,
+ # grids.task_contact,
+ # grids.task_corners,
+ # grids.task_detect,
+ # grids.task_fill,
+ # grids.task_frame,
+ # grids.task_grow,
+ # grids.task_half_fill,
+ # grids.task_isometry,
+ # grids.task_path,
+ # grids.task_replace_color,
+ # grids.task_scale,
+ grids.task_symbols,
+ # grids.task_trajectory,
+ # grids.task_translate,
+ ]:
+ # for t in [grids.task_path]:
start_time = time.perf_counter()
- prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
+ w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
delay = time.perf_counter() - start_time
- print(f"{t.__name__} {prompts.size(0)/delay:02f} seq/s")
+ print(f"{t.__name__} {w_quizzes.size(0)/delay:02f} seq/s")
+ grids.save_quizzes_as_image("/tmp", t.__name__ + ".png", w_quizzes[:128])
exit(0)
predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
- grids.save_quiz_illustrations(
+ grids.save_quizzes_as_image(
"/tmp",
- "test",
+ "test.png",
prompts[:nb],
answers[:nb],
# You can add a bool to put a frame around the predicted parts
# Written by Francois Fleuret <francois@fleuret.org>
-import math, sys, argparse, time, tqdm, os, datetime, warnings
+import math, sys, argparse, time, tqdm, os, datetime, warnings, copy
import torch, torchvision
from torch import nn
from torch.nn import functional as F
-import ffutils
+import ffutils, grids, attae
-import mygpt
-import sky, grids, quiz_machine
+import threading, subprocess
-import threading
+# import torch.multiprocessing as mp
-import torch.multiprocessing as mp
+torch.set_float32_matmul_precision("high")
+
+# torch.set_default_dtype(torch.bfloat16)
######################################################################
parser.add_argument("--resume", action="store_true", default=False)
-parser.add_argument("--max_percents_of_test_in_train", type=int, default=-1)
-
-########################################
+# ----------------------------------
parser.add_argument("--nb_epochs", type=int, default=10000)
-parser.add_argument("--batch_size", type=int, default=None)
+parser.add_argument("--batch_size", type=int, default=25)
-parser.add_argument("--physical_batch_size", type=int, default=None)
+parser.add_argument("--train_batch_size", type=int, default=None)
-parser.add_argument("--nb_train_samples", type=int, default=None)
+parser.add_argument("--eval_batch_size", type=int, default=25)
-parser.add_argument("--nb_test_samples", type=int, default=None)
+parser.add_argument("--nb_train_samples", type=int, default=50000)
-parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None)
+parser.add_argument("--nb_test_samples", type=int, default=2500)
-parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None)
+parser.add_argument("--nb_c_quizzes", type=int, default=5000)
+
+parser.add_argument("--c_quiz_multiplier", type=int, default=1)
parser.add_argument("--learning_rate", type=float, default=5e-4)
-########################################
+parser.add_argument("--nb_have_to_be_correct", type=int, default=3)
+
+parser.add_argument("--nb_have_to_be_wrong", type=int, default=1)
+
+parser.add_argument("--nb_mistakes_to_be_wrong", type=int, default=5)
+
+# ----------------------------------
+
+parser.add_argument("--model_type", type=str, default="standard")
-parser.add_argument("--model", type=str, default=None)
+parser.add_argument("--model", type=str, default="37M")
parser.add_argument("--dim_model", type=int, default=None)
parser.add_argument("--nb_blocks", type=int, default=None)
-parser.add_argument("--dropout", type=float, default=0.1)
+parser.add_argument("--dropout", type=float, default=0.5)
-########################################
-
-parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
-
-parser.add_argument("--problem", type=str, default="grids")
+# ----------------------------------
parser.add_argument("--nb_threads", type=int, default=1)
parser.add_argument("--gpus", type=str, default="all")
-parser.add_argument("--nb_gpts", type=int, default=5)
+# ----------------------------------
+
+parser.add_argument("--nb_models", type=int, default=5)
+
+parser.add_argument("--diffusion_nb_iterations", type=int, default=25)
-parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.9)
+parser.add_argument("--diffusion_proba_corruption", type=float, default=0.05)
-parser.add_argument("--proba_understands", type=float, default=0.9)
+parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95)
-parser.add_argument("--proba_not_understands", type=float, default=0.5)
+parser.add_argument("--proba_prompt_noise", type=float, default=0.05)
-parser.add_argument("--generation_temperature", type=float, default=1.0)
+parser.add_argument("--proba_hint", type=float, default=0.25)
-parser.add_argument("--dirty_debug", action="store_true", default=False)
+parser.add_argument("--quizzes", type=str, default=None)
######################################################################
)
parser.add_argument(
- "--grids_tasks",
+ "--grids_world_tasks",
type=str,
- default=None,
- help="A comma-separated subset of: " + grids_tasks + ", or None for all.",
+ default="replace_color,translate,grow,frame",
+ help="A comma-separated subset of: " + grids_tasks + ".",
)
######################################################################
-parser.add_argument("--sky_height", type=int, default=6)
-
-parser.add_argument("--sky_width", type=int, default=8)
-
-parser.add_argument("--sky_nb_birds", type=int, default=3)
-
-parser.add_argument("--sky_nb_iterations", type=int, default=2)
-
-parser.add_argument("--sky_speed", type=int, default=3)
-
-######################################################################
-
args = parser.parse_args()
if args.result_dir is None:
######################################################################
-default_args = {
- "model": "37M",
- "batch_size": 25,
- "nb_train_samples": 100000,
- "nb_test_samples": 10000,
-}
-
-for k, v in default_args.items():
- if getattr(args, k) is None:
- setattr(args, k, v)
-
-######################################################################
-
default_model_args = {
"17K": {
"dim_model": 32,
######################################################################
if args.resume:
- assert os.path.isdir(args.result_dir)
-
+ if not os.path.isdir(args.result_dir):
+ print(f"Trying to resume from a non-existing result dir {args.result_dir}.")
+ exit(1)
else:
try:
os.mkdir(args.result_dir)
def log_string(s):
+ """print the given string prefixed with a time stamps, and log it
+ into log_file is not None"""
+
t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
if log_file is not None:
sys.stdout.flush()
+######################################################################
+# Create a time-stamped archive of the source code
+
+with open("this_run.sh", "w") as f:
+ f.write(f"{' '.join(sys.argv)}\n")
+
now = time.strftime("%Y%m%d-%H%M%S", time.localtime())
-os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py")
+os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py *.sh")
+
+######################################################################
log_string(f"argv {' '.join(sys.argv)}")
assert len(gpus) == 0
main_device = torch.device("cpu")
-if args.dirty_debug:
- args.nb_train_samples = 2500
- args.nb_test_samples = 100
-
-if args.physical_batch_size is None:
- args.physical_batch_size = args.batch_size
+if args.train_batch_size is None:
+ args.train_batch_size = args.batch_size
else:
- assert args.batch_size % args.physical_batch_size == 0
+ assert args.batch_size % args.train_batch_size == 0
assert args.nb_train_samples % args.batch_size == 0
assert args.nb_test_samples % args.batch_size == 0
-if args.problem == "sky":
- problem = sky.Sky(
- height=args.sky_height,
- width=args.sky_width,
- nb_birds=args.sky_nb_birds,
- nb_iterations=args.sky_nb_iterations,
- speed=args.sky_speed,
- max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
- chunk_size=100,
- nb_threads=args.nb_threads,
- )
- back_accuracy = False
-elif args.problem == "grids":
- problem = grids.Grids(
- max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
- chunk_size=100,
- nb_threads=args.nb_threads,
- tasks=args.grids_tasks,
- )
- back_accuracy = True
-else:
- raise ValueError
-
-problem.save_some_examples(args.result_dir)
-
-quiz_machine = quiz_machine.QuizMachine(
- problem=problem,
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- back_accuracy=back_accuracy,
- batch_size=args.physical_batch_size,
- result_dir=args.result_dir,
- logger=log_string,
- device=main_device,
-)
+######################################################################
+
+
+def optimizer_to(optim, device):
+ """Move the optimizer optim to the device"""
+ for param in optim.state.values():
+ # Not sure there are any global tensors in the state dict
+ if isinstance(param, torch.Tensor):
+ param.data = param.data.to(device)
+ if param._grad is not None:
+ param._grad.data = param._grad.data.to(device)
+ elif isinstance(param, dict):
+ for subparam in param.values():
+ if isinstance(subparam, torch.Tensor):
+ subparam.data = subparam.data.to(device)
+ if subparam._grad is not None:
+ subparam._grad.data = subparam._grad.data.to(device)
+
######################################################################
-log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}")
-vocabulary_size = quiz_machine.vocabulary_size()
+def generate_quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1):
+ if c_quizzes is None:
+ quizzes = problem.generate_w_quizzes(nb_samples)
+ nb_w_quizzes = quizzes.size(0)
+ nb_c_quizzes = 0
+ else:
+ if c_quiz_multiplier > 1:
+ n = min(c_quiz_multiplier, (nb_samples // 2) // c_quizzes.size(0))
+ body = c_quizzes.repeat(n, 1)
+ if n < c_quiz_multiplier:
+ tail = c_quizzes[
+ torch.randperm(c_quizzes.size(0))[: nb_samples // 2 - body.size(0)]
+ ]
+ c_quizzes = torch.cat([body, tail], dim=0)
+ else:
+ c_quizzes = body
+
+ if c_quizzes.size(0) > nb_samples // 2:
+ i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2]
+ c_quizzes = c_quizzes[i]
+
+ w_quizzes = problem.generate_w_quizzes(nb_samples - c_quizzes.size(0))
+
+ quizzes = torch.cat([w_quizzes, c_quizzes], dim=0)
+ nb_w_quizzes = w_quizzes.size(0)
+ nb_c_quizzes = c_quizzes.size(0)
+
+ i = torch.randperm(quizzes.size(0), device=quizzes.device)
+ quizzes = quizzes[i].contiguous()
+
+ log_string(f"quiz_set nb_w_quizzes {nb_w_quizzes} nb_c_quizzes {nb_c_quizzes}")
+
+ return quizzes
-log_string(f"vocabulary_size {vocabulary_size}")
######################################################################
-def run_tests(model, quiz_machine, deterministic_synthesis, local_device=main_device):
- with torch.autograd.no_grad():
- model.eval().to(local_device)
+def add_hints_imt(imt_set):
+ """Set every component of the mask to zero with probability
+ args.proba_hint, and for each component set to zero, copy the
+ corresponding value from the target into the input
+
+ """
+ input, masks, targets = imt_set.unbind(dim=1)
+ # h = torch.rand(masks.size(), device=masks.device) - masks
+ # t = h.sort(dim=1).values[:, args.nb_hints, None]
+ # mask_hints = (h < t).long()
+ mask_hints = (
+ torch.rand(input.size(), device=input.device) < args.proba_hint
+ ).long() * masks
+ masks = (1 - mask_hints) * masks
+ input = (1 - mask_hints) * input + mask_hints * targets
+ return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
+
+
+def add_noise_imt(imt_set):
+ """Replace every component of the input by a random value with
+ probability args.proba_prompt_noise."""
+ input, masks, targets = imt_set.unbind(dim=1)
+ noise = problem.pure_noise(input.size(0), input.device)
+ change = (1 - masks) * (
+ torch.rand(input.size(), device=input.device) < args.proba_prompt_noise
+ ).long()
+ input = (1 - change) * input + change * noise
+ return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
+
- nb_test_samples, acc_test_loss = 0, 0.0
- nb_samples_accumulated = 0
+######################################################################
+# Prediction
- for input in quiz_machine.batches(model, split="test"):
- input = input.to(local_device)
- bs = model(mygpt.BracketedSequence(input))
- output = bs.x
+def samples_for_prediction_imt(input):
+ nb = input.size(0)
+ masks = input.new_zeros(input.size())
+ u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4)
+ masks.view(nb, 4, -1)[...] = u[:, :, None]
+ targets = input
+ input = (1 - masks) * targets
- loss = F.cross_entropy(output.transpose(1, 2), input)
+ return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
- acc_test_loss += loss.item() * input.size(0)
- nb_test_samples += input.size(0)
+def ae_predict(model, imt_set, local_device=main_device):
+ model.eval().to(local_device)
- test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
+ record = []
- log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}")
+ src = tqdm.tqdm(
+ imt_set.split(args.eval_batch_size),
+ dynamic_ncols=True,
+ desc="predict",
+ total=imt_set.size(0) // args.eval_batch_size,
+ delay=10,
+ )
- model.main_test_accuracy = quiz_machine.produce_results(
- n_epoch=n_epoch,
- model=model,
- result_dir=args.result_dir,
- deterministic_synthesis=deterministic_synthesis,
+ for imt in src:
+ # some paranoia
+ imt = imt.clone()
+ imt[:, 0] = imt[:, 0] * (1 - imt[:, 1])
+
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+ logits = model(imt[:, 0] * 2 + imt[:, 1])
+ dist = torch.distributions.categorical.Categorical(logits=logits)
+ result = (1 - imt[:, 1]) * imt[:, 0] + imt[:, 1] * dist.sample()
+ record.append(result)
+
+ return torch.cat(record)
+
+
+def predict_the_four_grids(
+ model, input, with_noise=False, with_hints=False, local_device=main_device
+):
+ input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1))
+ nb = input.size(0)
+ masks = input.new_zeros(input.size())
+ u = F.one_hot(torch.arange(nb, device=masks.device) % 4, num_classes=4)
+ masks.view(nb, 4, -1)[...] = u[:, :, None]
+ targets = input
+ input = (1 - masks) * targets
+ imt_set = torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
+
+ if with_hints:
+ imt_set = add_hints_imt(imt_set)
+
+ if with_noise:
+ imt_set = add_noise_imt(imt_set)
+
+ result = ae_predict(model, imt_set, local_device=local_device)
+ result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1)
+
+ return result
+
+
+######################################################################
+
+
+def samples_for_generation_imt(input):
+ nb = input.size(0)
+ probs_iterations = 0.1 ** torch.linspace(
+ 0, 1, args.diffusion_nb_iterations, device=input.device
+ )
+ probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
+ probs_iterations = probs_iterations.expand(nb, -1)
+ dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
+ t = dist.sample() + 1
+ r = torch.rand(input.size(), device=input.device)
+ proba_erased = 1 - (1 - args.diffusion_proba_corruption) ** t
+ mask_erased = (r <= proba_erased[:, None]).long()
+
+ noise = problem.pure_noise(nb, input.device)
+ targets = input
+ input = (1 - mask_erased) * input + mask_erased * noise
+ masks = input.new_full(input.size(), 1)
+
+ return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
+
+
+def prioritized_rand(low):
+ x = torch.rand(low.size(), device=low.device).sort(dim=1, descending=True).values
+ k = torch.rand(low.size(), device=low.device) + low.long()
+ k = k.sort(dim=1).indices
+ y = x.new(x.size())
+ y.scatter_(dim=1, index=k, src=x)
+ return y
+
+
+def ae_generate(model, nb, local_device=main_device):
+ model.eval().to(local_device)
+
+ # We loop through the iterations first and through the
+ # mini-batches second so that we keep only the samples that have
+ # not stabilized
+
+ all_input = problem.pure_noise(nb, local_device)
+ all_masks = all_input.new_full(all_input.size(), 1)
+ all_changed = torch.full((all_input.size(0),), True, device=all_input.device)
+
+ for it in range(args.diffusion_nb_iterations):
+ # log_string(f"nb_changed {all_changed.long().sum().item()}")
+
+ if not all_changed.any():
+ break
+
+ sub_input = all_input[all_changed].clone()
+ sub_masks = all_masks[all_changed].clone()
+ sub_changed = all_changed[all_changed].clone()
+
+ src = zip(
+ sub_input.split(args.eval_batch_size),
+ sub_masks.split(args.eval_batch_size),
+ sub_changed.split(args.eval_batch_size),
)
+ for input, masks, changed in src:
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+ logits = model(input * 2 + masks)
+ dist = torch.distributions.categorical.Categorical(logits=logits)
+ output = dist.sample()
+ r = prioritized_rand(input != output)
+ mask_changes = (r <= args.diffusion_proba_corruption).long() * masks
+ update = (1 - mask_changes) * input + mask_changes * output
+ changed[...] = changed & (update != input).max(dim=1).values
+ input[...] = update
-def one_epoch(model, quiz_machine, local_device=main_device):
- model.to(local_device).train()
+ a = all_changed.clone()
+ all_input[a] = sub_input
+ all_masks[a] = sub_masks
+ all_changed[a] = sub_changed
- optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+ return all_input
- nb_train_samples, acc_train_loss = 0, 0.0
- for input in quiz_machine.batches(model, split="train"):
- input = input.to(local_device)
+######################################################################
- if nb_train_samples % args.batch_size == 0:
- optimizer.zero_grad()
- output = model(mygpt.BracketedSequence(input)).x
- loss = F.cross_entropy(output.transpose(1, 2), input)
- acc_train_loss += loss.item() * input.size(0)
+def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device):
+ quizzes = generate_quiz_set(
+ args.nb_train_samples if train else args.nb_test_samples,
+ c_quizzes,
+ args.c_quiz_multiplier,
+ )
+
+ q_p, q_g = quizzes.to(local_device).chunk(2)
+
+ # Half of the samples train the prediction, and we inject noise in
+ # all, and hints in half
+ b_p = samples_for_prediction_imt(q_p)
+ b_p = add_noise_imt(b_p)
+ half = torch.rand(b_p.size(0)) < 0.5
+ b_p[half] = add_hints_imt(b_p[half])
+
+ # The other half are denoising examples for the generation
+ b_g = samples_for_generation_imt(q_g)
+
+ imt_set = torch.cat([b_p, b_g])
+ imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)]
+
+ if train:
+ label = "train"
+ model.train().to(local_device)
+ optimizer_to(model.optimizer, local_device)
+ batch_size = args.train_batch_size
+ else:
+ label = "test"
+ model.eval().to(local_device)
+ batch_size = args.eval_batch_size
+
+ nb_samples, acc_loss = 0, 0.0
+
+ for imt in tqdm.tqdm(
+ imt_set.split(batch_size),
+ dynamic_ncols=True,
+ desc=label,
+ total=quizzes.size(0) // batch_size,
+ delay=10,
+ ):
+ input, masks, targets = imt.unbind(dim=1)
+ if train and nb_samples % args.batch_size == 0:
+ model.optimizer.zero_grad()
+
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+ logits = model(input * 2 + masks)
+
+ loss_per_token = F.cross_entropy(
+ logits.transpose(1, 2), targets, reduction="none"
+ )
+ loss = (loss_per_token * masks).mean()
+ acc_loss += loss.item() * imt.size(0)
+ nb_samples += imt.size(0)
+
+ if train:
+ loss.backward()
- nb_train_samples += input.size(0)
+ if nb_samples % args.batch_size == 0:
+ model.optimizer.step()
- loss.backward()
+ log_string(f"{label}_loss {n_epoch} model {model.id} {acc_loss/nb_samples}")
- if nb_train_samples % args.batch_size == 0:
- optimizer.step()
- train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+######################################################################
- log_string(f"train_perplexity {n_epoch} model {model.id} {train_perplexity}")
- run_tests(model, quiz_machine, deterministic_synthesis=False)
+def save_inference_images(model, n_epoch, c_quizzes, c_quiz_multiplier, local_device):
+ # Save some images of the prediction results
- model.to(main_device)
+ quizzes = generate_quiz_set(150, c_quizzes, args.c_quiz_multiplier)
+ imt_set = samples_for_prediction_imt(quizzes.to(local_device))
+ result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
+ masks = imt_set[:, 1].to("cpu")
+
+ correct = (quizzes == result).min(dim=1).values.long()
+ correct_parts = (2 * correct - 1)[:, None] * masks.reshape(masks.size(0), 4, -1)[
+ :, :, 1
+ ]
+ predicted_parts = correct_parts.abs()
+
+ problem.save_quizzes_as_image(
+ args.result_dir,
+ f"culture_prediction_{n_epoch}_{model.id}.png",
+ quizzes=result[:128],
+ predicted_parts=predicted_parts[:128],
+ correct_parts=correct_parts[:128],
+ )
+
+ # Save some images of the ex nihilo generation of the four grids
+
+ result = ae_generate(model, 150, local_device=local_device).to("cpu")
+ problem.save_quizzes_as_image(
+ args.result_dir,
+ f"culture_generation_{n_epoch}_{model.id}.png",
+ quizzes=result[:128],
+ )
######################################################################
-# This is the key routine that decides what generated quizzes to keep
+def one_complete_epoch(
+ model, n_epoch, train_c_quizzes, test_c_quizzes, local_device=main_device
+):
+ one_epoch(model, n_epoch, train_c_quizzes, train=True, local_device=local_device)
+
+ one_epoch(model, n_epoch, test_c_quizzes, train=False, local_device=local_device)
+
+ # Compute the test accuracy
+
+ quizzes = generate_quiz_set(args.nb_test_samples, c_quizzes, args.c_quiz_multiplier)
+ imt_set = samples_for_prediction_imt(quizzes.to(local_device))
+ result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
+ correct = (quizzes == result).min(dim=1).values.long()
+
+ nb_correct, nb_total = correct.sum().item(), quizzes.size(0)
+ model.test_accuracy = nb_correct / nb_total
-# token_logprobas are NxMxT where M is the number of models
+ log_string(
+ f"test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({model.test_accuracy*100:.02f}%)"
+ )
+ save_inference_images(
+ model, n_epoch, c_quizzes, args.c_quiz_multiplier, local_device=local_device
+ )
-def compute_valid_quizzes_(token_logprobas):
- warnings.warn("validation with uniform constraints", RuntimeWarning)
- l = token_logprobas.min(dim=-1).values.sort(dim=-1).values
- return (l[:, 0] < math.log(0.1)) & (l[:, 1] > math.log(0.5))
+######################################################################
-def compute_valid_quizzes(token_logprobas):
- l = token_logprobas.sum(dim=-1).sort(dim=-1).values
- return (l[:, 0] < math.log(args.proba_not_understands)) & (
- l[:, 1] > math.log(args.proba_understands)
+
+def max_nb_mistakes_on_one_grid(quizzes, prediction):
+ return (
+ (prediction != quizzes)
+ .long()
+ .reshape(quizzes.size(0), 4, -1)
+ .sum(dim=2)
+ .max(dim=1)
+ .values
)
-def extract_valid_quizzes_and_logprobas(recorded):
- validated_quizzes, validated_logprobas = [], []
- for quizzes, token_logprobas in recorded:
- validated_indices = compute_valid_quizzes(token_logprobas)
- validated_quizzes.append(quizzes[validated_indices])
- validated_logprobas.append(token_logprobas[validated_indices])
+def evaluate_quizzes(quizzes, models, with_hints, local_device):
+ nb_correct, nb_wrong = 0, 0
- if len(validated_quizzes) > 0:
- return torch.cat(validated_quizzes, dim=0), torch.cat(
- validated_logprobas, dim=0
+ for model in models:
+ model = copy.deepcopy(model).to(local_device).eval()
+ predicted = predict_the_four_grids(
+ model=model,
+ input=quizzes,
+ with_noise=False,
+ with_hints=with_hints,
+ local_device=local_device,
)
- else:
- return None, None
+ nb_mistakes = max_nb_mistakes_on_one_grid(quizzes, predicted)
+ nb_correct += (nb_mistakes == 0).long()
+ nb_wrong += (nb_mistakes >= args.nb_mistakes_to_be_wrong).long()
+
+ # print("\n\n", nb_correct, nb_wrong)
+
+ return nb_correct, nb_wrong
######################################################################
-def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
- nb_to_create = nb_for_train + nb_for_test
+def identity_quizzes(quizzes):
+ quizzes = quizzes.reshape(quizzes.size(0), 4, -1)
+ return (quizzes[:, 0] == quizzes[:, 1]).min(dim=1).values | (
+ quizzes[:, 2] == quizzes[:, 3]
+ ).min(dim=1).values
- recorded_quizzes_logprobas = []
+def generate_c_quizzes(models, nb_to_generate, local_device=main_device):
+ record = []
nb_validated = 0
- while nb_validated < nb_to_create:
- model_for_generation = models[torch.randint(len(models), (1,))]
+ start_time = time.perf_counter()
+ last_log = -1
+
+ while nb_validated < nb_to_generate:
+ # Generate new quizzes
+
+ model = models[torch.randint(len(models), (1,)).item()]
+ model = copy.deepcopy(model).to(local_device).eval()
+ generator_id = model.id
- c_quizzes = quiz_machine.generate_quizzes(
- nb_to_create,
- model_for_generation=model_for_generation,
- temperature=args.generation_temperature,
+ c_quizzes = ae_generate(
+ model=model, nb=args.eval_batch_size * 10, local_device=local_device
)
- c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
+ c_quizzes = c_quizzes[identity_quizzes(c_quizzes) == False]
if c_quizzes.size(0) > 0:
- token_logproba = quiz_machine.solution_token_logprobas(models, c_quizzes)
- recorded_quizzes_logprobas.append((c_quizzes, token_logproba))
+ # Select the ones that are solved properly by some models and
+ # not understood by others
+
+ nb_correct, nb_wrong = evaluate_quizzes(
+ quizzes=c_quizzes,
+ models=models,
+ with_hints=True,
+ local_device=local_device,
+ )
+
+ to_keep = (nb_correct >= args.nb_have_to_be_correct) & (
+ nb_wrong >= args.nb_have_to_be_wrong
+ )
+
+ nb_validated += to_keep.long().sum().item()
+ record.append(c_quizzes[to_keep])
+
+ #####################
+
+ duration = time.perf_counter() - start_time
+
+ if last_log < 0 or duration > last_log + 10:
+ last_log = duration
+ if nb_validated > 0:
+ if nb_validated < nb_to_generate:
+ d = (nb_to_generate - nb_validated) * duration / nb_validated
+ e = (
+ datetime.datetime.now() + datetime.timedelta(seconds=d)
+ ).strftime("%a %H:%M")
+ else:
+ e = "now!"
+ else:
+ e = "???"
+
+ log_string(
+ f"nb_validated {nb_validated} model {generator_id} (finishes {e} -- {int((nb_validated * 3600)/duration)}/h)"
+ )
+
+ #####################
+
+ duration = time.perf_counter() - start_time
+
+ log_string(f"generate_c_quizz_speed {int(3600 * nb_validated / duration)}/h")
+
+ return torch.cat(record).to("cpu")
+
+
+######################################################################
+
+
+def multithread_execution(fun, arguments):
+ # Single instance, no thread
+ if len(arguments) == 1:
+ return fun(*(arguments[0]))
- (
- validated_quizzes,
- validated_logprobas,
- ) = extract_valid_quizzes_and_logprobas(recorded_quizzes_logprobas)
+ records, threads = [], []
- if validated_quizzes is not None:
- nb_validated = validated_quizzes.size(0)
+ def threadable_fun(*args):
+ r = fun(*args)
+ if type(r) is not tuple:
+ r = (r,)
+ records.append(r)
- log_string(
- f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create}"
+ for args in arguments:
+ # To get a different sequence between threads
+ log_string(f"dummy_rand {torch.rand(1)}")
+ # torch.rand(1)
+ t = threading.Thread(target=threadable_fun, daemon=True, args=args)
+ threads.append(t)
+ t.start()
+
+ for t in threads:
+ t.join()
+
+ if records[0] == (None,):
+ return
+ else:
+ return [
+ torch.cat([x[k] for x in records], dim=0) for k in range(len(records[0]))
+ ]
+
+
+######################################################################
+
+
+def save_models(models, suffix=""):
+ if suffix != "":
+ suffix = "_" + suffix
+
+ for model in models:
+ filename = f"ae_{model.id:03d}{suffix}.pth"
+ torch.save(
+ {
+ "state_dict": model.state_dict(),
+ "optimizer_state_dict": model.optimizer.state_dict(),
+ "test_accuracy": model.test_accuracy,
+ },
+ os.path.join(args.result_dir, filename),
)
- # store the new c_quizzes which have been validated
+ log_string(f"wrote ae_*{suffix}.pth")
+
+
+######################################################################
- quiz_machine.reverse_random_half_in_place(validated_quizzes)
- quiz_machine.store_c_quizzes(validated_quizzes[:nb_for_train], for_train=True)
- quiz_machine.store_c_quizzes(
- validated_quizzes[nb_for_train:nb_to_create], for_train=False
+
+def save_quiz_image(models, c_quizzes, filename, local_device=main_device):
+ c_quizzes = c_quizzes.to(local_device)
+
+ nb_correct, nb_wrong = evaluate_quizzes(
+ quizzes=c_quizzes,
+ models=models,
+ with_hints=False,
+ local_device=local_device,
)
- ######################################################################
- # save images with their logprobas
+ comments = [f"nb_correct {c} nb_wrong {w}" for c, w in zip(nb_correct, nb_wrong)]
- vq = validated_quizzes[:72]
- vl = validated_logprobas[:72]
+ problem.save_quizzes_as_image(
+ args.result_dir,
+ filename,
+ quizzes=c_quizzes,
+ comments=comments,
+ delta=True,
+ nrow=8,
+ )
- if vq.size(0) > 0:
- prefix = f"culture_c_quiz_{n_epoch:04d}"
- filename = os.path.join(args.result_dir, prefix + "_logp.pth")
- torch.save(vl, filename)
- # with open(file_name, "w") as logp_file:
- # for l in vl:
- # s = " ".join([str(x.item()) for x in l])
- # logp_file.write(s + "\n")
+ log_string(f"wrote {filename}")
- quiz_machine.save_quiz_illustrations(args.result_dir, prefix, vq)
+######################################################################
+
+problem = grids.Grids(
+ max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
+ chunk_size=100,
+ nb_threads=args.nb_threads,
+ tasks=args.grids_world_tasks,
+)
+
+if not args.resume:
+ problem.save_some_examples(args.result_dir)
+
+
+log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}")
+
+vocabulary_size = problem.vocabulary_size()
+
+log_string(f"vocabulary_size {vocabulary_size}")
######################################################################
models = []
-for k in range(args.nb_gpts):
- log_string(f"creating model {k} and its w_quizzes")
- model = mygpt.MyGPT(
- vocabulary_size=vocabulary_size,
+if args.model_type == "standard":
+ model_constructor = attae.AttentionAE
+elif args.model_type == "functional":
+ model_constructor = attae.FunctionalAttentionAE
+else:
+ raise ValueError(f"Unknown model type {args.model_type}")
+
+
+for i in range(args.nb_models):
+ model = model_constructor(
+ vocabulary_size=vocabulary_size * 2,
dim_model=args.dim_model,
dim_keys=args.dim_keys,
dim_hidden=args.dim_hidden,
nb_heads=args.nb_heads,
nb_blocks=args.nb_blocks,
- causal=True,
dropout=args.dropout,
- ).to(main_device)
+ )
- model.main_test_accuracy = 0.0
- model.id = k
+ # model = torch.compile(model)
- model.train_w_quizzes = quiz_machine.generate_token_sequences(args.nb_train_samples)
- quiz_machine.reverse_random_half_in_place(model.train_w_quizzes)
- model.test_w_quizzes = quiz_machine.generate_token_sequences(args.nb_test_samples)
- quiz_machine.reverse_random_half_in_place(model.test_w_quizzes)
+ model.id = i
+ model.test_accuracy = 0.0
+ model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
models.append(model)
######################################################################
-if args.resume:
- try:
- for model in models:
- filename = f"gpt_{model.id:03d}.pth"
-
- try:
- d = torch.load(os.path.join(args.result_dir, filename))
- model.load_state_dict(d[0])
- model.main_test_accuracy = d[1]
- log_string(f"successfully loaded {filename}")
- except FileNotFoundError:
- log_string(f"cannot find {filename}")
- pass
-
- try:
- filename = "c_quizzes.pth"
- quiz_machine.load_c_quizzes(os.path.join(args.result_dir, filename))
- log_string(f"successfully loaded {filename}")
- except FileNotFoundError:
- log_string(f"cannot find {filename}")
- pass
-
- except:
- log_string(f"error when loading {filename}.")
- exit(1)
-
-######################################################################
+current_epoch = 0
-nb_parameters = sum(p.numel() for p in models[0].parameters())
-log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
+if args.resume:
+ for model in models:
+ filename = f"ae_{model.id:03d}.pth"
-######################################################################
+ d = torch.load(
+ os.path.join(args.result_dir, filename),
+ map_location="cpu",
+ weights_only=False,
+ )
+ model.load_state_dict(d["state_dict"])
+ model.optimizer.load_state_dict(d["optimizer_state_dict"])
+ model.test_accuracy = d["test_accuracy"]
+ log_string(f"successfully loaded {filename}")
+
+ filename = "state.pth"
+ state = torch.load(
+ os.path.join(args.result_dir, filename),
+ map_location="cpu",
+ weights_only=False,
+ )
-# Compute the entropy of the training tokens
+ log_string(f"successfully loaded {filename}")
-token_count = 0
-for input in quiz_machine.batches(models[0], split="train", desc="train-entropy"):
- token_count += F.one_hot(input, num_classes=quiz_machine.vocabulary_size()).sum(
- (0, 1)
- )
-token_probas = token_count / token_count.sum()
-entropy = -torch.xlogy(token_probas, token_probas).sum()
-train_set_perplexity = math.exp(entropy)
+ current_epoch = state["current_epoch"]
+ train_c_quizzes = state["train_c_quizzes"]
+ test_c_quizzes = state["test_c_quizzes"]
######################################################################
-# A bit of paranoia never hurts
-
-if args.max_percents_of_test_in_train >= 0:
-
- def subsets_as_tuples(batches, cs):
- s = set()
- for batch in batches:
- for x in batch:
- s.add(tuple([v.item() for v in x]))
- if len(s) == cs:
- yield s
- s = set()
- yield s
-
- nb_test, nb_in_train = 0, 0
- for test_subset in subsets_as_tuples(
- quiz_machine.batches(models[0], split="test", desc="test-check"), 25000
- ):
- in_train = set()
- for train_subset in subsets_as_tuples(
- quiz_machine.batches(models[0], split="train", desc="train-check"), 25000
- ):
- in_train.update(test_subset.intersection(train_subset))
- nb_in_train += len(in_train)
- nb_test += len(test_subset)
- log_string(
- f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
- )
+nb_parameters = sum(p.numel() for p in models[0].parameters())
+log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
- assert (
- nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
- ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"
######################################################################
-if args.nb_new_c_quizzes_for_train is None:
- args.nb_new_c_quizzes_for_train = args.nb_train_samples // 50
-
-if args.nb_new_c_quizzes_for_test is None:
- args.nb_new_c_quizzes_for_test = args.nb_test_samples // 50
-
-log_string(
- f"nb_new_c_quizzes_for_train {args.nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {args.nb_new_c_quizzes_for_test}"
-)
+train_c_quizzes, test_c_quizzes = None, None
######################################################################
-if args.dirty_debug:
- args.accuracy_to_make_c_quizzes = 0.0
- args.nb_gpts = 2
- args.nb_new_c_quizzes_for_train = 100
- args.nb_new_c_quizzes_for_test = 10
+for n_epoch in range(current_epoch, args.nb_epochs):
+ start_time = time.perf_counter()
+ state = {
+ "current_epoch": n_epoch,
+ "train_c_quizzes": train_c_quizzes,
+ "test_c_quizzes": test_c_quizzes,
+ }
-######################################################################
+ filename = "state.pth"
+ torch.save(state, os.path.join(args.result_dir, filename))
+ log_string(f"wrote {filename}")
-for n_epoch in range(args.nb_epochs):
log_string(f"--- epoch {n_epoch} ----------------------------------------")
- cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models])
+ cta = " ".join([f"{float(m.test_accuracy):.04f}" for m in models])
log_string(f"current_test_accuracies {cta}")
- ##################################################
- # If all the models are good enough, generate new quizzes and
- # re-compute the test errors
+ # --------------------------------------------------------------------
- if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
- create_c_quizzes(
- models,
- quiz_machine,
- nb_for_train=args.nb_new_c_quizzes_for_train,
- nb_for_test=args.nb_new_c_quizzes_for_test,
- )
-
- filename = "c_quizzes.pth"
- quiz_machine.save_c_quizzes(os.path.join(args.result_dir, filename))
- log_string(f"wrote {filename}")
+ lowest_test_accuracy = min([float(m.test_accuracy) for m in models])
- # Force one epoch of training
- for model in models:
- model.main_test_accuracy = 0.0
+ if lowest_test_accuracy >= args.accuracy_to_make_c_quizzes:
+ if train_c_quizzes is None:
+ save_models(models, "naive")
- ##################################################
- # Select, improve, and eval the worst model
+ nb_gpus = len(gpus)
+ nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus
- ranked_models = sorted(models, key=lambda m: float(m.main_test_accuracy))
+ (new_c_quizzes,) = multithread_execution(
+ generate_c_quizzes,
+ [(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus],
+ )
- weakest_models = ranked_models[: len(gpus)]
+ save_quiz_image(
+ models, new_c_quizzes[:256], f"culture_c_quiz_{n_epoch:04d}.png"
+ )
- threads = []
+ log_string(f"generated_c_quizzes {new_c_quizzes.size()}")
- for gpu, model in zip(gpus, weakest_models):
- log_string(f"training model {model.id}")
+ train_c_quizzes = (
+ new_c_quizzes
+ if train_c_quizzes is None
+ else torch.cat([train_c_quizzes, new_c_quizzes])
+ )
+ train_c_quizzes = train_c_quizzes[-args.nb_train_samples :]
- t = threading.Thread(
- target=one_epoch, daemon=True, args=(model, quiz_machine, gpu)
+ nb_correct, _ = evaluate_quizzes(
+ quizzes=train_c_quizzes,
+ models=models,
+ with_hints=False,
+ local_device=local_device,
)
- threads.append(t)
+ test_c_quizzes = train_c_quizzes[nb_correct >= args.nb_have_to_be_correct]
- t.start()
+ for model in models:
+ model.test_accuracy = 0
- for t in threads:
- t.join()
+ if train_c_quizzes is None:
+ log_string("no_c_quiz")
+ else:
+ log_string(f"nb_c_quizzes {train_c_quizzes.size(0)}")
- # Save the models to disk
+ # --------------------------------------------------------------------
- for model in weakest_models:
- filename = f"gpt_{model.id:03d}.pth"
- torch.save(
- (model.state_dict(), model.main_test_accuracy),
- os.path.join(args.result_dir, filename),
- )
- log_string(f"wrote {filename}")
+ ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
+ weakest_models = ranked_models[: len(gpus)]
+
+ log_string(
+ f"weakest_accuracies {[model.test_accuracy for model in weakest_models]}"
+ )
- # Renew the training samples
+ multithread_execution(
+ one_complete_epoch,
+ [
+ (model, n_epoch, train_c_quizzes, test_c_quizzes, gpu)
+ for model, gpu in zip(weakest_models, gpus)
+ ],
+ )
- for model in weakest_models:
- quiz_machine.renew_w_quizzes(model, args.nb_train_samples)
+ save_models(models)
+ # --------------------------------------------------------------------
-######################################################################
+ duration = time.perf_counter() - start_time
+ str_duration = ""
+ if duration >= 60:
+ str_duration += f"{int(duration)//60}min"
+ str_duration += f"{int(duration)%60}s"
+ str_next = (
+ datetime.datetime.now() + datetime.timedelta(seconds=duration)
+ ).strftime("%H:%M:%S")
+ log_string(f"epoch_duration {str_duration} next_finish {str_next}")
+++ /dev/null
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-# This is an implementation from scratch of a "GPT", that is a model
-# composed of several causal self-attention blocks. It is equipped
-# with a caching mechanism for keys and values to avoid a O(N^3) cost
-# for auto-regression.
-
-import math
-
-import torch
-
-from torch import nn
-from torch.nn import functional as F
-
-######################################################################
-
-# A BracketedSequence is a BxTx... tensor with a first and a nb time
-# steps to compute.
-
-# Modules able to process it expect that they will have to process a
-# first bracket starting at t=0, followed by a succession of brackets
-# that move forward in time, do not overlap, and cover the axis T with
-# no holes.
-#
-# Although it is more general, for a classical prompt-conditioned
-# auto-regressive process it will be a first bracket starting at 0 and
-# of arbitrary length for the "prompt", followed by brackets of length
-# 1 for the successive tokens.
-#
-# Modules able to process brackets may implement a cache that is
-# resetted when the input bracket starts at t=0
-
-
-class BracketedSequence:
- def __init__(self, x, first=None, nb=None):
- self.x = x
- self.first = 0 if first is None else first
- self.nb = x.size(1) if nb is None else nb
-
- def slice(self):
- return self.x[:, self.first : self.first + self.nb]
-
- def complete(self):
- return self.first == 0 and self.nb == self.x.size(1)
-
-
-######################################################################
-
-
-class CacheWrapper(nn.Module):
- def __init__(self, *f):
- super().__init__()
- self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
-
- def forward(self, bs):
- if bs.first == 0:
- y = self.f(bs.slice())
- self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:]))
- self.cache_y[:, bs.first : bs.first + bs.nb] = y
- else:
- self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
-
- return BracketedSequence(self.cache_y, bs.first, bs.nb)
-
-
-##############################
-
-
-class WithResidual(nn.Module):
- def __init__(self, *f):
- super().__init__()
- self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
-
- def forward(self, bs):
- return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb)
-
-
-##############################
-
-
-class AddPositionalEncoding(nn.Module):
- def __init__(self, len_max):
- super().__init__()
- self.len_max = len_max
-
- # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
-
- def forward(self, bs):
- if bs.first == 0:
- t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
- :, None
- ]
- j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
- None, :
- ]
- k = j % 2
- self.pe = torch.sin(
- t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
- )
- self.cache_y = bs.x.new(bs.x.size())
-
- self.cache_y[:, bs.first : bs.first + bs.nb] = (
- bs.slice() + self.pe[bs.first : bs.first + bs.nb]
- )
-
- return BracketedSequence(self.cache_y, bs.first, bs.nb)
-
-
-##############################
-
-
-class QKVAttention(nn.Module):
- def __init__(
- self,
- dim_in,
- dim_qk,
- dim_v,
- nb_heads=1,
- causal=False,
- attention_dropout=0.0,
- ):
- super().__init__()
-
- def randw(*d):
- return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
-
- self.causal = causal
- self.attention_dropout = attention_dropout
- self.record_attention = False
-
- self.w_q = randw(nb_heads, dim_qk, dim_in)
- self.w_k = randw(nb_heads, dim_qk, dim_in)
- self.w_v = randw(nb_heads, dim_v, dim_in)
- self.w_o = randw(dim_v * nb_heads, dim_in)
-
- def forward(self, bs_q):
- x_q = bs_q.x
-
- assert (
- self.causal or bs_q.complete()
- ), "Partial evaluation is only possible for causal models"
-
- if bs_q.first == 0:
- self.cache_k = x_q.new_zeros(
- x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
- )
- self.cache_v = x_q.new_zeros(
- x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
- )
- self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
-
- q = torch.einsum(
- "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_q
- )
-
- self.cache_k[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum(
- "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_k
- )
- self.cache_v[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum(
- "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_v
- )
-
- a = torch.einsum(
- "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb]
- ) / math.sqrt(self.w_q.size(1))
-
- if self.causal:
- if bs_q.first == 0:
- self.cache_attzero = (
- torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
- < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
- )
- a = a.masked_fill(
- self.cache_attzero[
- :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb
- ],
- float("-inf"),
- )
-
- a = a.softmax(dim=3)
-
- if self.record_attention:
- self.a = a
-
- a = F.dropout(a, self.attention_dropout, self.training)
-
- y = torch.einsum(
- "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs_q.first + bs_q.nb]
- ).flatten(2)
-
- self.cache_y[:, bs_q.first : bs_q.first + bs_q.nb] = y @ self.w_o
-
- return BracketedSequence(self.cache_y, bs_q.first, bs_q.nb)
-
-
-##############################
-
-
-class NoiseInjector(nn.Module):
- def __init__(self):
- super().__init__()
- self.noise_std = 0.0
-
- def forward(self, x):
- if self.noise_std > 0:
- x = x + torch.randn(x.size(), device=x.device) * self.noise_std
- return x
-
-
-def set_noise_injection(model, noise_std):
- for m in model.modules():
- if isinstance(m, NoiseInjector):
- m.noise_std = noise_std
-
-
-##############################
-
-
-class MyGPT(nn.Module):
- def __init__(
- self,
- vocabulary_size,
- dim_model,
- dim_keys,
- dim_hidden,
- nb_heads,
- nb_blocks,
- causal=False,
- dropout=0.0,
- len_max=1e5,
- ):
- super().__init__()
-
- assert dim_model % nb_heads == 0
-
- self.embedding = nn.Sequential(
- CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
- AddPositionalEncoding(len_max),
- )
-
- trunk_blocks = []
-
- for b in range(nb_blocks):
- trunk_blocks += [
- WithResidual(
- CacheWrapper(
- nn.LayerNorm((dim_model,)),
- NoiseInjector(),
- ),
- QKVAttention(
- dim_in=dim_model,
- dim_qk=dim_keys,
- dim_v=dim_model // nb_heads,
- nb_heads=nb_heads,
- causal=causal,
- attention_dropout=dropout,
- ),
- ),
- WithResidual(
- CacheWrapper(
- nn.LayerNorm((dim_model,)),
- NoiseInjector(),
- nn.Linear(in_features=dim_model, out_features=dim_hidden),
- nn.ReLU(),
- nn.Linear(in_features=dim_hidden, out_features=dim_model),
- nn.Dropout(dropout),
- ),
- ),
- ]
-
- self.trunk = nn.Sequential(*trunk_blocks)
-
- self.readout = CacheWrapper(
- nn.Linear(in_features=dim_model, out_features=vocabulary_size)
- )
-
- with torch.no_grad():
- for m in self.modules():
- if isinstance(m, nn.Embedding):
- m.weight.normal_(mean=0, std=2e-2)
- elif isinstance(m, nn.LayerNorm):
- m.bias.zero_()
- m.weight.fill_(1.0)
-
- def forward(self, bs):
- # print(f"GENERATE {bs.first} {bs.first+bs.nb}")
- bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
- bs = self.embedding(bs)
- bs = self.trunk(bs)
- bs = self.readout(bs)
- return bs
-
- def record_attention(self, v=True):
- for m in self.modules():
- if isinstance(m, QKVAttention):
- m.record_attention = v
-
- def retrieve_attention(self):
- a = []
- for m in self.modules():
- if isinstance(m, QKVAttention):
- a.append(m.a)
- return a
-
-
-######################################################################
-
-if __name__ == "__main__":
- print("Basic check.")
-
- vocabulary_size = 3
- x = torch.randint(vocabulary_size, (1, 5))
-
- model = MyGPT(
- vocabulary_size=vocabulary_size,
- dim_model=4,
- dim_keys=2,
- dim_hidden=2,
- nb_heads=2,
- nb_blocks=2,
- dropout=0.1,
- causal=True,
- )
-
- model.eval()
- y1 = model(BracketedSequence(x)).x
- y2 = torch.randn_like(y1)
- for s in range(x.size(1)):
- z = model(BracketedSequence(x, s, 1))
- y2[:, s] = z.slice()
-
- print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
-
-######################################################################
else:
return self.queue.qsize() * self.chunk_size
- def nb_token_values(self):
- pass
+ def fill_cache(self):
+ while True:
+ quizzes = self.generate_w_quizzes_(self.chunk_size)
+ self.queue.put(quizzes.to("cpu"), block=True)
+
+ def generate_w_quizzes(self, nb, progress_bar=True):
+ if self.queue is None:
+ return self.generate_w_quizzes_(nb)
+
+ if self.rest is not None:
+ quizzes = rest
+ else:
+ quizzes = []
+
+ self.rest = None
+
+ n = sum([q.size(0) for q in quizzes])
+
+ if progress_bar:
+ with tqdm.tqdm(
+ total=nb, dynamic_ncols=True, desc="world generation", delay=10
+ ) as pbar:
+ while n < nb:
+ q = self.queue.get(block=True)
+ quizzes.append(q)
+ n += q.size(0)
+ pbar.update(q.size(0))
+ else:
+ while n < nb:
+ q = self.queue.get(block=True)
+ quizzes.append(q)
+ n += q.size(0)
+
+ quizzes = torch.cat(quizzes, dim=0)
+ assert n == quizzes.size(0)
+
+ k = n - nb
+
+ if k > 0:
+ rest = quizzes[-k:]
+ quizzes = quizzes[:-k]
+
+ return quizzes
+
+ ######################################################################
def trivial_prompts_and_answers(self, prompts, answers):
pass
# The one to implement, returns two tensors nb x D and nb x D'
- def generate_prompts_and_answers_(self, nb):
+ def generate_w_quizzes_(self, nb):
pass
# save a file to vizualize quizzes, you can save a txt or png file
):
pass
- def fill_cache(self):
- while True:
- prompts, answers = self.generate_prompts_and_answers_(self.chunk_size)
-
- self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True)
-
- def generate_prompts_and_answers(self, nb):
- if self.queue is None:
- return self.generate_prompts_and_answers_(nb)
-
- if self.rest is not None:
- prompts, answers = rest
- else:
- prompts, answers = [], []
-
- self.rest = None
-
- n = sum([p.size(0) for p in prompts])
-
- with tqdm.tqdm(
- total=nb,
- dynamic_ncols=True,
- desc="world generation",
- ) as pbar:
- while n < nb:
- p, s = self.queue.get(block=True)
- prompts.append(p)
- answers.append(s)
- n += p.size(0)
- pbar.update(p.size(0))
-
- prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0)
- assert n == prompts.size(0)
-
- k = n - nb
-
- if k > 0:
- rest = (prompts[-k:], answers[-k:])
- prompts, answers = prompts[:-k], answers[:-k]
-
- return prompts, answers
-
def save_some_examples(self, result_dir):
pass
+
+ ######################################################################
+++ /dev/null
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import math, os, tqdm, warnings, sys
-
-import torch, torchvision
-
-from torch import nn
-from torch.nn import functional as F
-
-import mygpt
-from mygpt import BracketedSequence
-
-import threading
-
-######################################################################
-# if output is log(P(X=y)) and target is Y, returns -log P(X=Y) + H(X
-# | X != Y)
-
-
-# output is NxCxT and target is NxT
-def confusion(output, target, reduction="mean"):
- N, C, T = output.shape
- output = output.permute(0, 2, 1).reshape(-1, C)
- target = target.flatten()
- all_t = torch.arange(N * T, device=output.device)
- output = output.log_softmax(dim=-1)
- result = -output[all_t, target]
-
- output[all_t, target] = float("-inf")
- output = output.log_softmax(dim=-1)
- e = output.exp()
- output[all_t, target] = 0
- result = result - (output * e).sum(-1)
-
- if reduction == "none":
- return result.reshape(N, T)
- elif reduction == "mean":
- return result.reshape(N, T).mean()
- elif reduction == "sum":
- return result.reshape(N, T).sum()
- else:
- raise ValueError(f"unknown reduction '{reduction}'.")
-
-
-######################################################################
-
-# ar_mask is a tensor with 0s and 1s, of same shape as input, with
-# 1s where tokens should be generated. The others are kept
-# unchanged.
-
-
-def one_batch_masked_inplace_autoregression(
- model,
- input,
- ar_mask,
- seq_logproba,
- temperature,
- deterministic_synthesis,
-):
- to_generate = (ar_mask.sum(0) > 0).nonzero()
-
- if to_generate.min() > 0:
- model(
- BracketedSequence(input, 0, to_generate.min())
- ) # Needed to initialize the model's cache
- for s in range(to_generate.min(), to_generate.max() + 1):
- output = model(BracketedSequence(input, s, 1)).x
-
- logits = output[:, s]
-
- logits = (logits / temperature).log_softmax(dim=-1)
-
- if deterministic_synthesis:
- t_next = logits.argmax(-1)
- else:
- dist = torch.distributions.categorical.Categorical(logits=logits)
- t_next = dist.sample()
-
- all_n = torch.arange(t_next.size(0))
-
- seq_logproba += logits[all_n, t_next]
-
- input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
-
-
-def masked_inplace_autoregression(
- model,
- batch_size,
- input,
- ar_mask,
- seq_logproba,
- temperature,
- deterministic_synthesis,
- forbidden_tokens=None,
- logit_biases=None,
- progress_bar_desc=None,
- device=torch.device("cpu"),
-):
- assert input.size() == ar_mask.size()
-
- batches = zip(
- input.split(batch_size),
- ar_mask.split(batch_size),
- seq_logproba.split(batch_size),
- )
-
- if progress_bar_desc is not None:
- batches = tqdm.tqdm(
- batches,
- dynamic_ncols=True,
- desc=progress_bar_desc,
- total=(input.size(0) + batch_size - 1) // batch_size,
- )
-
- with torch.autograd.no_grad():
- t = model.training
- model.eval()
-
- for input, ar_mask, seq_logproba in batches:
- one_batch_masked_inplace_autoregression(
- model=model,
- input=input,
- ar_mask=ar_mask,
- seq_logproba=seq_logproba,
- temperature=temperature,
- deterministic_synthesis=deterministic_synthesis,
- )
-
- model.train(t)
-
-
-######################################################################
-
-
-class QuizMachine:
- def indices_forward_and_backward(self, quizzes):
- i_forward = quizzes[:, 0] == self.token_forward
- j_forward = quizzes[:, 1 + self.prompt_len] == self.token_forward
- i_backward = quizzes[:, 0] == self.token_backward
- j_backward = quizzes[:, 1 + self.answer_len] == self.token_backward
- assert torch.logical_or(
- torch.logical_and(i_forward, j_forward),
- torch.logical_and(i_backward, j_backward),
- ).all()
- return i_forward, i_backward
-
- def non_trivial(self, quizzes):
- quizzes = quizzes.clone()
- n_forward = quizzes[quizzes[:, 0] == self.token_forward]
- n_backward = quizzes[:, 0] == self.token_backward
- backward = quizzes[n_backward]
- quizzes[n_backward] = self.reverse_time(quizzes[n_backward])
- return torch.logical_not(
- self.problem.trivial_prompts_and_answers(
- quizzes[:, 1 : 1 + self.prompt_len],
- quizzes[:, 2 + self.prompt_len :],
- )
- )
-
- def reverse_time(self, quizzes):
- i_forward, i_backward = self.indices_forward_and_backward(quizzes)
-
- forward_to_backward = torch.cat(
- [
- quizzes[:, 0:1],
- quizzes[:, 2 + self.prompt_len : 2 + self.prompt_len + self.answer_len],
- quizzes[:, 1 + self.prompt_len : 1 + self.prompt_len + 1],
- quizzes[:, 1 : 1 + self.prompt_len],
- ],
- dim=1,
- )
-
- forward_to_backward[:, 0] = self.token_backward
- forward_to_backward[:, 1 + self.answer_len] = self.token_backward
-
- backward_to_forward = torch.cat(
- [
- quizzes[:, 0:1],
- quizzes[:, 2 + self.answer_len :],
- quizzes[:, 1 + self.answer_len : 2 + self.answer_len],
- quizzes[:, 1 : 1 + self.answer_len],
- ],
- dim=1,
- )
-
- backward_to_forward[:, 0] = self.token_forward
- backward_to_forward[:, 1 + self.prompt_len] = self.token_forward
-
- m = i_forward.long()[:, None]
-
- return m * forward_to_backward + (1 - m) * backward_to_forward
-
- def reverse_random_half_in_place(self, quizzes):
- i = torch.rand(quizzes.size(0)) < 0.5
- if i.any():
- quizzes[i] = self.reverse_time(quizzes[i])
-
- def make_ar_mask(self, quizzes, first=False):
- i_forward, i_backward = self.indices_forward_and_backward(quizzes)
-
- t = torch.arange(quizzes.size(1), device=quizzes.device)
-
- if first:
- m_forward = (t >= 1).long() * (t < 1 + self.prompt_len).long()
- m_backward = (t >= 1).long() * (t < 1 + self.answer_len).long()
- else:
- m_forward = (t >= 2 + self.prompt_len).long()
- m_backward = (t >= 2 + self.answer_len).long()
-
- m = i_forward.long()[:, None]
-
- return m * m_forward + (1 - m) * m_backward
-
- def generate_token_sequences(self, nb):
- prompts, answers = self.problem.generate_prompts_and_answers(nb)
-
- if self.prompt_len is None:
- self.prompt_len = prompts.size(1)
-
- if self.answer_len is None:
- self.answer_len = answers.size(1)
-
- assert prompts.size(1) == self.prompt_len and answers.size(1) == self.answer_len
-
- result = []
-
- for prompt, answer in zip(prompts, answers):
- a = [
- torch.tensor([self.token_forward]),
- prompt,
- torch.tensor([self.token_forward]),
- answer,
- ]
-
- result.append(torch.cat(a, dim=0)[None, :])
-
- return torch.cat(result, dim=0)
-
- def __init__(
- self,
- problem,
- nb_train_samples,
- nb_test_samples,
- back_accuracy,
- batch_size,
- result_dir,
- logger,
- device=torch.device("cpu"),
- ):
- super().__init__()
-
- v = problem.nb_token_values()
- self.token_forward = v
- self.token_backward = v + 1
- self.nb_token_values = v + 2
-
- self.problem = problem
- self.back_accuracy = back_accuracy
- self.batch_size = batch_size
- self.device = device
- self.logger = logger
- self.prompt_len = None
- self.answer_len = None
-
- self.LOCK_C_QUIZZES = threading.Lock()
- self.train_c_quizzes = []
- self.test_c_quizzes = []
-
- def save_quiz_illustrations(
- self,
- result_dir,
- filename_prefix,
- quizzes,
- mistakes=None,
- ):
- quizzes = quizzes.clone().to("cpu")
- n_forward = quizzes[quizzes[:, 0] == self.token_forward]
- n_backward = quizzes[:, 0] == self.token_backward
- backward = quizzes[n_backward]
- assert n_forward.size(0) + backward.size(0) == quizzes.size(0)
- quizzes[n_backward] = self.reverse_time(quizzes[n_backward])
-
- predicted_prompts = n_backward.long()
- predicted_answers = 1 - predicted_prompts
- if mistakes is not None:
- # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
- predicted_prompts *= mistakes.to("cpu")
- predicted_answers *= mistakes.to("cpu")
- else:
- # 0/2 ~ not-to-predict / to predict
- predicted_prompts *= 2
- predicted_answers *= 2
-
- self.problem.save_quiz_illustrations(
- result_dir,
- filename_prefix,
- quizzes[:, 1 : 1 + self.prompt_len],
- quizzes[:, 2 + self.prompt_len :],
- predicted_prompts,
- predicted_answers,
- )
-
- def vocabulary_size(self):
- return self.nb_token_values
-
- ######################################################################
-
- def batches(self, model, split="train", desc=None):
- assert split in {"train", "test"}
-
- with self.LOCK_C_QUIZZES:
- if split == "train":
- w_quizzes = model.train_w_quizzes
- c_quizzes = self.train_c_quizzes
- else:
- w_quizzes = model.test_w_quizzes
- c_quizzes = self.test_c_quizzes
-
- if len(c_quizzes) > 0:
- c_quizzes = torch.cat(c_quizzes, dim=0)
- if c_quizzes.size(0) > w_quizzes.size(0) // 2:
- i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
- c_quizzes = c_quizzes[i]
-
- i = torch.randperm(w_quizzes.size(0))[
- : w_quizzes.size(0) - c_quizzes.size(0)
- ]
- w_quizzes = w_quizzes[i]
-
- self.nb_batch_w_quizzes = w_quizzes.size(0)
- self.nb_batch_c_quizzes = c_quizzes.size(0)
-
- input = torch.cat([w_quizzes, c_quizzes], dim=0)
- else:
- input = w_quizzes
- self.nb_batch_w_quizzes = w_quizzes.size(0)
- self.nb_batch_c_quizzes = 0
-
- # Shuffle
- input = input[torch.randperm(input.size(0))]
-
- if desc is None:
- desc = f"epoch-{split}"
- for batch in tqdm.tqdm(
- input.split(self.batch_size), dynamic_ncols=True, desc=desc
- ):
- yield batch
-
- ######################################################################
-
- def produce_results(
- self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
- ):
- def compute_accuracy(input, log_prefix=None):
- input = input.to(self.device)
- ar_mask = self.make_ar_mask(input)
- result = input.clone() * (1 - ar_mask)
- seq_logproba = torch.empty(input.size(0), device=self.device)
-
- masked_inplace_autoregression(
- model=model,
- batch_size=self.batch_size,
- input=result,
- ar_mask=ar_mask,
- seq_logproba=seq_logproba,
- temperature=1.0,
- deterministic_synthesis=deterministic_synthesis,
- progress_bar_desc=None,
- device=self.device,
- )
-
- correct = torch.empty(input.size(0), dtype=torch.int64, device=input.device)
-
- n_forward = input[:, 0] == self.token_forward
- n_backward = input[:, 0] == self.token_backward
-
- correct[n_forward] = (
- (input[n_forward] == result[n_forward]).long().min(dim=1).values
- )
-
- if self.back_accuracy and n_backward.any():
- # accuracy of B->A*->B*=B instead of B->A*=A
- back_input = self.reverse_time(result[n_backward])
- back_input[:, 2 + self.prompt_len :] = input[
- n_backward, 1 : 1 + self.answer_len
- ]
- _, correct[n_backward] = compute_accuracy(back_input)
-
- if log_prefix is not None:
- forward_nb_correct = correct[n_forward].sum()
- forward_nb_total = correct[n_forward].size(0)
- backward_nb_correct = correct[n_backward].sum()
- backward_nb_total = correct[n_backward].size(0)
-
- self.logger(
- f"{log_prefix}_accuracy {n_epoch} model {model.id} forward {forward_nb_correct} / {forward_nb_total} backward {backward_nb_correct} / {backward_nb_total}"
- )
-
- return result, correct
-
- # compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train")
-
- test_result, test_correct = compute_accuracy(
- model.test_w_quizzes[:nmax], log_prefix="test"
- )
-
- main_test_accuracy = test_correct.sum() / test_correct.size(0)
- self.logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}")
-
- ##############################
-
- self.save_quiz_illustrations(
- result_dir,
- f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
- quizzes=test_result[:72],
- mistakes=test_correct[:72] * 2 - 1,
- )
-
- return main_test_accuracy
-
- ######################################################################
-
- def renew_w_quizzes(self, model, nb, for_train=True):
- input = model.train_w_quizzes if for_train else model.test_w_quizzes
- nb = min(nb, input.size(0))
- input[:-nb] = input[nb:].clone()
- fresh_w_quizzes = self.generate_token_sequences(nb)
- self.reverse_random_half_in_place(fresh_w_quizzes)
- input[-nb:] = fresh_w_quizzes.to("cpu")
-
- ######################################################################
-
- def store_c_quizzes(self, new_c_quizzes, for_train=True):
- with self.LOCK_C_QUIZZES:
- if for_train:
- self.train_c_quizzes.append(new_c_quizzes.to("cpu"))
- else:
- self.test_c_quizzes.append(new_c_quizzes.to("cpu"))
-
- def save_c_quizzes(self, filename):
- torch.save((self.train_c_quizzes, self.test_c_quizzes), filename)
-
- def load_c_quizzes(self, filename):
- self.train_c_quizzes, self.test_c_quizzes = torch.load(filename)
-
- ######################################################################
-
- def solution_token_logprobas(self, models, c_quizzes):
- logproba = c_quizzes.new_zeros(
- c_quizzes.size(0),
- len(models),
- c_quizzes.size(1),
- device=self.device,
- dtype=torch.float32,
- )
-
- for model in models:
- with torch.autograd.no_grad():
- t = model.training
- model.eval()
-
- for input, l in zip(
- c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
- ):
- input = input.to(self.device)
- ar_mask = self.make_ar_mask(input)
- output = model(mygpt.BracketedSequence(input)).x
- l[:, model.id] = (
- -F.cross_entropy(
- output.transpose(1, 2), input, reduction="none"
- )
- * ar_mask
- )
-
- model.train(t)
-
- return logproba.to("cpu")
-
- ###############################################################
-
- def compute_correctness(
- self,
- c_quizzes,
- models_for_validation,
- bidirectional_validation=False,
- deterministic_validation=True,
- ):
- if bidirectional_validation:
- backward_c_quizzes = self.forward_to_backward(c_quizzes)
-
- seq_logproba = torch.zeros(
- c_quizzes.size(0),
- max([m.id for m in models_for_validation]) + 1,
- device=self.device,
- )
-
- nb_correct = 0
-
- seq_logproba[...] = 0.0
-
- for model in models_for_validation:
- result = c_quizzes.clone()
-
- ar_mask = self.make_ar_mask(result)
-
- masked_inplace_autoregression(
- model=model,
- batch_size=self.batch_size,
- input=result,
- ar_mask=ar_mask,
- seq_logproba=seq_logproba[:, model.id],
- temperature=1.0,
- deterministic_synthesis=deterministic_validation,
- # progress_bar_desc="solving c_quizzes",
- device=self.device,
- )
-
- correct = (c_quizzes == result).long().min(dim=-1).values
-
- if bidirectional_validation:
- backward_result = backward_c_quizzes.clone()
-
- ar_mask = self.make_ar_mask(backward_result)
-
- masked_inplace_autoregression(
- model=model,
- batch_size=self.batch_size,
- input=backward_result,
- ar_mask=ar_mask,
- seq_logproba=seq_logproba[:, model.id],
- temperature=1.0,
- deterministic_synthesis=deterministic_validation,
- # progress_bar_desc="solving backward c_quizzes",
- device=self.device,
- )
-
- backward_correct = (
- (backward_c_quizzes == backward_result).long().min(dim=-1).values
- )
-
- correct *= backward_correct
-
- # endif
-
- nb_correct += correct
-
- return nb_correct, seq_logproba
-
- ###############################################################
-
- def generate_quizzes(self, nb, model_for_generation, temperature=1.0):
- c_quizzes = torch.empty(
- nb,
- self.prompt_len + self.answer_len + 2,
- device=self.device,
- dtype=torch.int64,
- )
-
- seq_logproba = torch.zeros(nb, device=self.device)
-
- # First, we generate the answer at high temperature
-
- c_quizzes[:, 0] = self.token_backward
- c_quizzes[:, 1 + self.answer_len] = self.token_backward
-
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes,
- ar_mask=self.make_ar_mask(c_quizzes, first=True),
- seq_logproba=seq_logproba,
- temperature=temperature,
- deterministic_synthesis=False,
- device=self.device,
- )
-
- # Then, we generate the prompt at low temperature
-
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes,
- ar_mask=self.make_ar_mask(c_quizzes),
- seq_logproba=seq_logproba,
- temperature=1 / temperature,
- deterministic_synthesis=False,
- device=self.device,
- )
-
- # Then we return the quizz, and re-generate the response, now
- # at low temperature
-
- c_quizzes = self.reverse_time(c_quizzes)
-
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes,
- ar_mask=self.make_ar_mask(c_quizzes),
- seq_logproba=seq_logproba,
- temperature=1 / temperature,
- deterministic_synthesis=False,
- device=self.device,
- )
-
- return c_quizzes.to("cpu")
--- /dev/null
+%% -*- mode: latex; mode: reftex; mode: flyspell; coding: utf-8; tex-command: "pdflatex.sh" -*-
+
+%% Any copyright is dedicated to the Public Domain.
+%% https://creativecommons.org/publicdomain/zero/1.0/
+%% Written by Francois Fleuret <francois@fleuret.org>
+
+\documentclass[11pt,a4paper,oneside]{article}
+\usepackage[paperheight=15cm,paperwidth=8cm,top=2mm,bottom=15mm,right=5mm,left=5mm]{geometry}
+%\usepackage[a4paper,top=2.5cm,bottom=2cm,left=2.5cm,right=2.5cm]{geometry}
+\usepackage[utf8]{inputenc}
+\usepackage{amsmath,amssymb,dsfont}
+\usepackage[pdftex]{graphicx}
+\usepackage[colorlinks=true,linkcolor=blue,urlcolor=blue,citecolor=blue]{hyperref}
+\urlstyle{same}
+\usepackage{tikz}
+\usetikzlibrary{arrows,arrows.meta,calc}
+\usetikzlibrary{patterns,backgrounds}
+\usetikzlibrary{positioning,fit}
+\usetikzlibrary{shapes.geometric,shapes.multipart}
+\usetikzlibrary{patterns.meta,decorations.pathreplacing,calligraphy}
+\usetikzlibrary{tikzmark}
+\usetikzlibrary{decorations.pathmorphing}
+\usepackage[round]{natbib}
+\usepackage[osf]{libertine}
+\usepackage{microtype}
+
+\usepackage{mleftright}
+
+\usepackage{enumitem}
+\setlist[itemize]{leftmargin=0pt,itemindent=1em,itemsep=2ex}
+\setlist{nosep} % or \setlist{noitemsep} to leave space around whole list
+
+\newcommand{\setmuskip}[2]{#1=#2\relax}
+\setmuskip{\thinmuskip}{1.5mu} % by default it is equal to 3 mu
+\setmuskip{\medmuskip}{2mu} % by default it is equal to 4 mu
+\setmuskip{\thickmuskip}{3.5mu} % by default it is equal to 5 mu
+
+\setlength{\parindent}{0cm}
+\setlength{\parskip}{1ex}
+%\renewcommand{\baselinestretch}{1.3}
+%\setlength{\tabcolsep}{0pt}
+%\renewcommand{\arraystretch}{1.0}
+
+\def\argmax{\operatornamewithlimits{argmax}}
+\def\argmin{\operatornamewithlimits{argmin}}
+
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+
+\def\given{\,\middle\vert\,}
+\def\proba{\operatorname{P}}
+\newcommand{\seq}{{S}}
+\newcommand{\expect}{\mathds{E}}
+\newcommand{\variance}{\mathds{V}}
+\newcommand{\empexpect}{\hat{\mathds{E}}}
+\newcommand{\mutinf}{\mathds{I}}
+\newcommand{\empmutinf}{\hat{\mathds{I}}}
+\newcommand{\entropy}{\mathds{H}}
+\newcommand{\empentropy}{\hat{\mathds{H}}}
+\newcommand{\ganG}{\mathbf{G}}
+\newcommand{\ganD}{\mathbf{D}}
+\newcommand{\ganF}{\mathbf{F}}
+
+\newcommand{\dkl}{\mathds{D}_{\mathsf{KL}}}
+\newcommand{\djs}{\mathds{D}_{\mathsf{JS}}}
+
+\newcommand*{\vertbar}{\rule[-1ex]{0.5pt}{2.5ex}}
+\newcommand*{\horzbar}{\rule[.5ex]{2.5ex}{0.5pt}}
+
+\def\positionalencoding{\operatorname{pos-enc}}
+\def\concat{\operatorname{concat}}
+\def\crossentropy{\LL_{\operatorname{ce}}}
+
+\newcommand{\separator}{\begin{center}
+*
+\end{center}}
+
+\newcommand{\pic}[2]{%
+\hspace*{\stretch{1}}
+%
+\includegraphics[scale=0.25]{#1}
+%
+\hspace*{\stretch{1}}%
+}
+
+\newcommand{\birdpic}[2]{%
+\hspace*{\stretch{1}}
+%
+\includegraphics[scale=0.35]{#1}
+%
+\hspace*{\stretch{1}}%
+}
+
+\newenvironment{example}{%
+
+\vspace*{2ex}
+
+\begin{minipage}{\textwidth}
+
+\setlength{\parindent}{0cm}
+\setlength{\parskip}{1ex}
+}{%
+\end{minipage}
+}
+
+\begin{document}
+
+\vspace*{-3ex}
+
+\begin{center}
+
+{\Large Self-Generated Culture}
+
+Fran\c cois Fleuret
+
+\today
+
+\vspace*{2ex}
+
+\centerline{\color{red}(work in progress, to be updated)}
+
+\medskip
+
+\centerline{\url{https://fleuret.org/public/culture/culture.pdf}}
+
+\end{center}
+
+\section{Introduction}
+
+The hypothesis behind this experiment is that high-level abstract
+thinking is fueled by social competition.
+
+A group of communicating agents that try to demonstrate their
+cognitive superiority would end up developing a rich and consistent
+culture.
+
+\subsection{Setup}
+
+The experiment is designed with a group of GPTs that alternatively
+learn to solve quizzes and generate new ones.
+
+A ``quiz'' is a pair composed of a prompt and a solution, both being
+sequence of tokens.
+
+We differentiate \textbf{world quizzes} that follow pre-defined and
+fixed regularities, and mimic the world's physical and environmental
+patterns that an organism has to grasp to survive, and \textbf{culture
+ quizzes} that are generated by the GPTs, and mimic the knowledge one
+has to master to perform socially.
+
+
+We train five GPTs on a a very large set of ``world quizzes''
+generated randomly. These models are trained to generate both the
+solution given the prompt, and the prompt given the solution.
+
+This is achieved by using for training both ``forward sequences'',
+composed of a token \texttt{[fwd]}, followed by the prompt's tokens,
+followed by another token \texttt{[fwd]}, followed by the solution's
+tokens, or ``backward sequences'' composed of a token \texttt{[bck]},
+followed by the solution's tokens, followed by another token
+\texttt{[bck]}, followed by the prompt's tokens,
+
+\subsection{Generating Culture Quizzes}
+
+When their accuracy get above $95\%$ we generate new quizzes as follows:
+%
+\begin{enumerate}
+
+\item generate a solution (without conditioning) at temperature $T=2$,
+ then generate a prompt for that solution at temperature $T=1/2$, and
+ then generate a solution for that prompt at temperature $T=1/2$.
+
+\item generate one solution for that prompt with each of the $5$ GPTs
+ at temperature $T=1$, if $4$ of them generate the correct solution,
+ validate that quiz and include it in the training data.
+
+\end{enumerate}
+
+This criterion assures that the new quizzes are both solvable and
+sophisticated, and incrementally complexify the culture. Imposing both
+direction prevents the generation of quizzes which are not trivial
+only because the prompt has been randomly degraded.
+
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+
+\pagebreak
+
+\section{Grid Quizzes}
+
+\subsection{World Quizzes}
+
+We define several types of quizzes and implement algorithmic
+procedures to generate randomly as many examples from each that we
+need.
+
+In these quizzes, the prompt is made of three grids $A, f(A), B$ and
+the solution is a single grid $f(B)$.
+
+\subsubsection{Half Fill}
+
+\pic{pics/task_color_grow.png}{``half fill''}
+
+The first grid contains three rectangles, each with a vertical or an
+horizontal line of another color in its middle. The second grid is
+identical with one of the rectangle having one half filled. The third
+grid contains three rectangles of identical colors as the firs grid,
+of different size and locations. The solution is obtained by filling
+similarly one of the half of a rectangle of the third image.
+
+\subsubsection{Detect}
+
+\pic{pics/task_detect.png}{``detect''}
+
+The first grid contains three rectangles, the second has two pixels of
+same colors located in the top-left corner of two of them. The
+solution is obtained by marking in the fourth image the top-left
+corners of the rectangles of same colors in the third.
+
+\subsubsection{Frame}
+
+\pic{pics/task_frame.png}{``frame''}
+
+The first grid contains three rectangles, and the second is identical
+except that one rectangle has been replaced by its frame. The same
+should be done to the similarly colored rectangles of the third grid
+to obtain the solution.
+
+\subsubsection{Grow}
+
+\pic{pics/task_grow.png}{``grow''}
+
+The first grid contains three rectangles, one of them getting one
+pixel thicker or thinner in the second. The same should be done to the
+similarly colored rectangles of the third grid to get the solution.
+
+\subsubsection{Replace color}
+
+\pic{pics/task_replace_color.png}{``replace color''}
+
+The first grid contains three rectangles, the second is obtained by
+changing one of the colors. The same should be done to the third grid
+to obtain the solution.
+
+\subsubsection{Translate}
+
+\pic{pics/task_translate.png}{``translate''}
+
+The first grid contains three rectangles. The second is obtained by
+displacing one of them by one pixel in both direction. The solution is
+obtained by applying the same motion to the similarly colored
+rectangle in the third grid.
+
+%% \subsubsection{Bounce}
+
+%% \pic{pics/task_bounce.png}{``bounce''}
+
+%% The solution should join the two pixels of same color, with a path of
+%% another color, starting in the direction indicated by a pixel of that
+%% color, and changing direction only when colliding with a pixel of a
+%% third color or one of the lattice border.
+
+%% \subsubsection{count}
+
+%% \pic{pics/task_count.png}{``count''}
+
+%% \subsubsection{scale}
+
+%% \pic{pics/task_scale.png}{``scale''}
+
+%% \subsubsection{trajectory}
+
+%% \pic{pics/task_trajectory.png}{``trajectory''}
+
+\subsection{Culture Quizzes}
+
+We list here some generated quizzes that exhibit features that were not present in the ``world quizzes'' used for training.
+
+\bigskip
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0110_N4_validated/quiz_63.png}{0110/63}
+
+\pic{pics/culture_c_quiz_0115_N4_validated/quiz_37.png}{0115/37}
+
+The quizzes ``frame'' and ``half fill'' have been combined in a single
+quiz.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0120_N4_validated/quiz_05.png}{0110/05}
+
+The ``frame'' quiz has been generalized to non-rectangular shapes.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_01.png}{0078/01}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_02.png}{0078/02}
+
+More rectangles were added as distractors.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0087_N4_validated/quiz_62.png}{0087/62}
+
+\pic{pics/culture_c_quiz_0102_N4_validated/quiz_04.png}{0102/04}
+
+\pic{pics/culture_c_quiz_0102_N4_validated/quiz_11.png}{0102/11}
+
+\pic{pics/culture_c_quiz_0108_N4_validated/quiz_31.png}{0108/31}
+
+Variation of ``Detect'' with location markers colored according to the
+color of the rectangle they mark.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_16.png}{0078/16}
+
+\pic{pics/culture_c_quiz_0084_N4_validated/quiz_21.png}{0084/21}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_42.png}{0078/42}
+
+\pic{pics/culture_c_quiz_0089_N4_validated/quiz_28.png}{0089/28}
+
+\pic{pics/culture_c_quiz_0084_N4_validated/quiz_00.png}{0084/00}
+
+Variations of ``Half Fill'', ``Detect'', ``Translate'', ``Grow'', and
+``Frame'' with a number of rectangles not equal to three.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_27.png}{0078/27}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_18.png}{0078/18}
+
+\pic{pics/culture_c_quiz_0086_N4_validated/quiz_45.png}{0086/45}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_37.png}{0078/37}
+
+Variations of ``Half Fill'' where the shapes to change have more
+complex coloring.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_30.png}{0078/30}
+
+Variation of ``Translate'' where the moving part is occluded, which
+was never the case.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_31.png}{0078/31}
+
+\pic{pics/culture_c_quiz_0084_N4_validated/quiz_10.png}{0084/10}
+
+\pic{pics/culture_c_quiz_0084_N4_validated/quiz_12.png}{0084/12}
+
+\pic{pics/culture_c_quiz_0086_N4_validated/quiz_23.png}{0086/23}
+
+\pic{pics/culture_c_quiz_0086_N4_validated/quiz_28.png}{0086/28}
+
+Variations of ``Half Fill'' with non-rectangular shapes.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_60.png}{0078/60}
+
+\pic{pics/culture_c_quiz_0084_N4_validated/quiz_41.png}{0084/41}
+
+\pic{pics/culture_c_quiz_0084_N4_validated/quiz_49.png}{0084/49}
+
+\pic{pics/culture_c_quiz_0086_N4_validated/quiz_04.png}{0086/04}
+
+Variations of ``Half Fill'' with two colors or two rectangles have to
+be modified.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0111_N4_validated/quiz_23.png}{0111/23}
+
+Variation of ``Frame'' with no rectangle of adequate size to be
+modified.
+
+\end{example}
+
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+
+\pagebreak
+
+\section{Bird World}
+
+These results were obtained with a slightly different procedure. In
+particular the quizzes were validated if the models could predict both
+the solution from the prompt and the prompt from the solution. We
+report them since they exhibit the same patterns of generalization
+although they are quite different.
+
+\subsection{World Quizzes}
+
+The initial set of quizzes consist of predicting the dynamics of a
+very simple world: A $6 \times 8$ grid with three colored ``birds'' moving in
+a straight line, possibly bouncing on the grid's borders. There are
+ten different colors.
+%
+\birdpic{pics/examples_train.png}{}
+%
+
+In each on these quizzes, $A$ is the left image serialized in
+raster-scan order as a sequence of $6 \times 8 = 48$ tokens, $d$ is
+either the token ``forward'' or the token ``backward'', and $B$ is the
+right image, also serialized. The direction of prediction is chosen at
+random.
+
+\subsection{Culture quizzes}
+
+This procedure results in the discovery of patterns which are not
+present in the original quizzes:
+
+\begin{example}
+
+\birdpic{pics/4_birds_1.png}{}
+
+\birdpic{pics/5_birds_1.png}{}
+
+\birdpic{pics/6_birds_1.png}{}
+
+More birds.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\birdpic{pics/other_shapes_2.png}{}
+
+\birdpic{pics/other_shapes_3.png}{}
+
+New bird shapes.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\birdpic{pics/other_shapes_1.png}{}
+
+\birdpic{pics/occlusions_1.png}{}
+
+Occlusions.
+
+\end{example}
+
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+
+\pagebreak
+
+\section{Various thoughts}
+
+\begin{itemize}
+
+\item The whole process can be envisioned as natural selection of
+ quizzes in the representation landscape of GPTs. There probably is a
+ subtle relation between the temperature (mutation rate) and the
+ number of models used to validate with the ``all but one'' criterion
+ (survival criterion).
+
+\item The ``all but one'' could be ``all but K'', and there may be
+ some information-theoretical thing, where the goal is to maximize
+ mutual information, with $K=N$ being total randomness, so high
+ entropy but no structure, and $K=0$ is total determinism, so no
+ information to share.
+
+\item The setup does not push toward any specific invariance or
+ property in the generated quizzes, their consistency is entirely due
+ to the statistics of the ``world quizzes'' that remain in the
+ training set, and to the GPTs' inductive biased.
+
+\item The GPTs obviously get a sense of objectness and 2d topology
+ early on, since they rapidly increase the number of birds and
+ ``discover'' occlusion even though they never was in the world
+ quizzes.
+
+\item There may not be so many problems that can be cast as pairs of
+ patterns that are each a deterministic function of the other, which
+ is probably critical here.
+
+\item This overall process probably fight the ``simplicity bias'': If
+ a model is lacking a ``cue'' that the others have, there will
+ rapidly be quizzes that require this cue, they will be added to the
+ training data, and that model will catch up.
+
+\item The randomness of the process probably allow to even go beyond
+ just synchronizing the abilities of the models. There may be some
+ additional complexification of quizzes that get accepted by chance.
+
+\item It can be parallelized by dispatching the GPTs across multiples
+ nodes, and avoiding a quadratic cost by limiting the validation of
+ the quizzes to a subset of them.
+
+\item The current process to generate new quizzes, which simply
+ samples them at random is very rudimentary and probably not
+ sufficient in a real-data setup. It can probably be supplemented
+ with a MCTS-type search.
+
+\item There may be already in the generated quizzes some structure
+ that \emph{we} do not pick up (e.g. certain color or motion
+ patterns).
+
+\end{itemize}
+
+\section*{Appendix}
+
+The code is available at
+
+\medskip
+
+\centerline{\url{https://fleuret.org/git/culture}}
+
+The experiments are done with a GTX 4090.
+
+The GPT used has 37M parameters and the following structure:
+
+\begin{center}
+\begin{tabular}{lc}
+ \texttt{dim\_model} & 512 \\
+ \texttt{dim\_keys} & 64 \\
+ \texttt{dim\_hidden} & 2048 \\
+ \texttt{nb\_heads} & 8 \\
+ \texttt{nb\_blocks} & 12
+\end{tabular}
+\end{center}
+
+Adam, $\eta = 1e-4$, no scheduling.
+
+There are $N_{\text{train}}=250'000$ original quizzes for training and
+$N_{\text{test}} = 10'000$ for test.
+
+At each epoch, for both train and test samples, we mix original
+quizzes and the generated ones.
+
+For training for instance, if there are less than $N_{\text{train}}/2$
+new quizzes, we take all of them, otherwise we sample
+$N_{\text{train}}/2$ of them without replacement, and then we sample
+without replacement enough original quizzes to get $N_{\text{train}}$
+samples in total.
+
+We proceed similarly to get $N_{\text{test}}$ samples for test.
+
+\end{document}
+++ /dev/null
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import math, sys, tqdm, os, warnings
-
-import torch, torchvision
-
-from torch import nn
-from torch.nn import functional as F
-
-######################################################################
-
-import problem
-
-
-class Sky(problem.Problem):
- colors = torch.tensor(
- [
- [255, 255, 255],
- [255, 0, 0],
- [0, 192, 0],
- [0, 0, 255],
- [255, 192, 0],
- [0, 255, 255],
- [255, 0, 255],
- [192, 255, 192],
- [255, 192, 192],
- [192, 192, 255],
- [192, 192, 192],
- ]
- )
-
- token_background = 0
- first_bird_token = 1
- nb_bird_tokens = colors.size(0) - 1
-
- token2char = (
- "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
- )
-
- def __init__(
- self,
- height=6,
- width=8,
- nb_birds=3,
- speed=2,
- nb_iterations=2,
- avoid_collision=True,
- max_nb_cached_chunks=None,
- chunk_size=None,
- nb_threads=-1,
- ):
- super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
- self.height = height
- self.width = width
- self.nb_birds = nb_birds
- self.speed = speed
- self.nb_iterations = nb_iterations
- self.avoid_collision = avoid_collision
-
- def generate_frame_sequences(self, nb):
- frame_sequences = []
-
- for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
- i, j, vi, vj = (
- torch.empty(self.nb_birds, dtype=torch.int64),
- torch.empty(self.nb_birds, dtype=torch.int64),
- torch.empty(self.nb_birds, dtype=torch.int64),
- torch.empty(self.nb_birds, dtype=torch.int64),
- )
-
- def collision_okay():
- if not self.avoid_collision:
- return True
-
- count = torch.zeros(self.height, self.width, dtype=torch.int64)
-
- for n in range(self.nb_birds):
- count[i[n], j[n]] += 1
- count[i[n] - vi[n], j[n]] += 1
- count[i[n], j[n] - vj[n]] += 1
-
- return count.max() <= 1
-
- col = (
- torch.randperm(self.colors.size(0) - 1)[: self.nb_birds].sort().values
- + 1
- )
-
- while True:
- while True:
- for n in range(self.nb_birds):
- while True:
- i[n] = torch.randint(self.height, (1,))
- j[n] = torch.randint(self.width, (1,))
- vm = torch.randint(4, (1,))
- vi[n], vj[n] = (vm % 2) * 2 - 1, (vm // 2) * 2 - 1
- if (
- i[n] - vi[n] >= 0
- and i[n] - vi[n] < self.height
- and j[n] - vj[n] >= 0
- and j[n] - vj[n] < self.width
- ):
- break
-
- if collision_okay():
- break
-
- result = torch.zeros(
- self.nb_iterations * self.speed,
- self.height,
- self.width,
- dtype=torch.int64,
- )
-
- fine = torch.empty(self.nb_iterations * self.speed)
-
- t_to_keep = (
- torch.arange(self.nb_iterations, device=result.device) * self.speed
- )
-
- for l in range(self.nb_iterations * self.speed):
- fine[l] = collision_okay()
- for n in range(self.nb_birds):
- c = col[n]
- result[l, i[n], j[n]] = c
- result[l, i[n] - vi[n], j[n]] = c
- result[l, i[n], j[n] - vj[n]] = c
-
- if (i[n] == 0 and vi[n] == -1) or (
- i[n] == self.height - 1 and vi[n] == 1
- ):
- vi[n] = -vi[n]
-
- if (j[n] == 0 and vj[n] == -1) or (
- j[n] == self.width - 1 and vj[n] == 1
- ):
- vj[n] = -vj[n]
-
- i[n] += vi[n]
- j[n] += vj[n]
-
- result = result[t_to_keep]
- fine = fine[t_to_keep]
-
- if fine[-1]:
- break
-
- frame_sequences.append(result)
-
- return frame_sequences
-
- ######################################################################
-
- def frame2img(self, x, scale=15):
- x = x.reshape(x.size(0), self.height, -1)
- m = torch.logical_and(
- x >= 0, x < self.first_bird_token + self.nb_bird_tokens
- ).long()
- x = self.colors[x * m].permute(0, 3, 1, 2)
- s = x.shape
- x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
- x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
-
- x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
- x[:, :, torch.arange(0, x.size(2), scale), :] = 0
- x = x[:, :, 1:, 1:]
-
- for n in range(m.size(0)):
- for i in range(m.size(1)):
- for j in range(m.size(2)):
- if m[n, i, j] == 0:
- for k in range(2, scale - 2):
- for l in [0, 1]:
- x[n, :, i * scale + k, j * scale + k - l] = 0
- x[
- n, :, i * scale + scale - 1 - k, j * scale + k - l
- ] = 0
-
- return x
-
- def seq2str(self, seq):
- result = []
- for s in seq:
- result.append("".join([self.token2char[v] for v in s]))
- return result
-
- def save_image(
- self,
- result_dir,
- filename,
- prompts,
- answers,
- predicted_prompts=None,
- predicted_answers=None,
- ):
- if predicted_prompts is None:
- predicted_prompts = 255
-
- if predicted_answers is None:
- predicted_answers = 255
-
- def add_frame(x, c, margin, bottom=False):
- if bottom:
- h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
- else:
- h, w, di, dj = (
- x.size(2) + 2 * margin,
- x.size(3) + 2 * margin,
- margin,
- margin,
- )
-
- y = x.new_full((x.size(0), x.size(1), h, w), 0)
-
- if type(c) is int:
- y[...] = c
- else:
- c = c.long()[:, None]
- c = (
- (c == 1).long() * torch.tensor([0, 255, 0], device=c.device)
- + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device)
- + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device)
- )
- y[...] = c[:, :, None, None]
-
- y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
-
- return y
-
- margin = 4
-
- img_prompts = add_frame(self.frame2img(prompts.to("cpu")), c=0, margin=1)
- h = img_prompts.size(2)
- img_answers = add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1)
-
- img_prompts = add_frame(img_prompts, c=255, margin=margin, bottom=True)
- img_answers = add_frame(img_answers, c=255, margin=margin, bottom=True)
-
- img_prompts = add_frame(
- img_prompts, c=predicted_prompts, margin=margin, bottom=True
- )
- img_answers = add_frame(
- img_answers, c=predicted_answers, margin=margin, bottom=True
- )
-
- marker_size = 16
-
- separator = img_prompts.new_full(
- (
- img_prompts.size(0),
- img_prompts.size(1),
- img_prompts.size(2),
- marker_size,
- ),
- 255,
- )
-
- separator[:, :, 0] = 0
- separator[:, :, h - 1] = 0
-
- for k in range(1, 2 * marker_size - 8):
- i = k - (marker_size - 4)
- j = marker_size - 5 - abs(i)
- separator[:, :, h // 2 - 1 + i, 2 + j] = 0
- separator[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
-
- img = torch.cat([img_prompts, separator, img_answers], dim=3)
-
- image_name = os.path.join(result_dir, filename)
- torchvision.utils.save_image(
- img.float() / 255.0, image_name, nrow=6, padding=margin * 4, pad_value=1.0
- )
-
- ######################################################################
-
- def nb_token_values(self):
- return len(self.colors)
-
- def generate_prompts_and_answers(self, nb):
- frame_sequences = self.generate_frame_sequences(nb)
- frame_sequences = torch.cat([x[None] for x in frame_sequences], dim=0)
-
- prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1)
-
- answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1)
-
- # warnings.warn("dirty test with longer answer", RuntimeWarning)
- # answers = torch.cat(
- # [
- # frame_sequences[:, frame_sequences.size(1) // 2 :],
- # frame_sequences[:, frame_sequences.size(1) // 2 :],
- # ],
- # dim=3,
- # ).flatten(1)
-
- return prompts, answers
-
- def save_quiz_illustrations(
- self,
- result_dir,
- filename_prefix,
- prompts,
- answers,
- predicted_prompts=None,
- predicted_answers=None,
- ):
- self.save_image(
- result_dir,
- filename_prefix + ".png",
- prompts,
- answers,
- predicted_prompts,
- predicted_answers,
- )
-
-
-######################################################################
-
-if __name__ == "__main__":
- import time
-
- sky = Sky(height=6, width=8, speed=1, nb_iterations=4)
-
- prompts, answers = sky.generate_prompts_and_answers(4)
-
- predicted_prompts = torch.randint(3, (prompts.size(0),)) - 1
- predicted_answers = torch.randint(3, (prompts.size(0),)) - 1
-
- sky.save_quiz_illustrations(
- "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers
- )
-
- # start_time = time.perf_counter()
- # token_sequences = sky.generate_token_sequences(nb=64)
- # delay = time.perf_counter() - start_time
- # print(f"{token_sequences.size(0)/delay:02f} seq/s")
-
- # print(sky.seq2str(seq[:4]))
-
- # for t in range(len(it[0])):
- # img = torch.cat([sky.frame2img(f[t]) for f in it], dim=0)
- # torchvision.utils.save_image(
- # img.float() / 255.0,
- # f"/tmp/frame_{t:03d}.png",
- # nrow=8,
- # padding=6,
- # pad_value=0,
- # )
-
- # m = (torch.rand(seq.size()) < 0.05).long()
- # seq = (1 - m) * seq + m * 23
-
- # print(seq.size())
- # img = sky.seq2img(token_sequences)
- # print(img.size())
-
- # torchvision.utils.save_image(
- # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0
- # )
+++ /dev/null
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import math, sys, tqdm, os
-
-import torch, torchvision
-
-from torch import nn
-from torch.nn import functional as F
-
-######################################################################
-
-import problem
-
-
-class Wireworld(problem.Problem):
- colors = torch.tensor(
- [
- [128, 128, 128],
- [128, 128, 255],
- [255, 0, 0],
- [255, 255, 0],
- ]
- )
-
- token_empty = 0
- token_head = 1
- token_tail = 2
- token_conductor = 3
- token_forward = 4
- token_backward = 5
-
- token2char = (
- "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
- )
-
- def __init__(
- self, height=6, width=8, nb_objects=2, nb_walls=2, speed=1, nb_iterations=4
- ):
- self.height = height
- self.width = width
- self.nb_objects = nb_objects
- self.nb_walls = nb_walls
- self.speed = speed
- self.nb_iterations = nb_iterations
-
- def direction_tokens(self):
- return self.token_forward, self.token_backward
-
- def generate_frame_sequences(self, nb):
- result = []
- N = 100
- for _ in tqdm.tqdm(
- range(0, nb + N, N), dynamic_ncols=True, desc="world generation"
- ):
- result.append(self.generate_frame_sequences_hard(100))
- return torch.cat(result, dim=0)[:nb]
-
- def generate_frame_sequences_hard(self, nb):
- frame_sequences = []
- nb_frames = (self.nb_iterations - 1) * self.speed + 1
-
- result = torch.full(
- (nb * 4, nb_frames, self.height, self.width),
- self.token_empty,
- )
-
- for n in range(result.size(0)):
- while True:
- i = torch.randint(self.height, (1,))
- j = torch.randint(self.width, (1,))
- v = torch.randint(2, (2,))
- vi = v[0] * (v[1] * 2 - 1)
- vj = (1 - v[0]) * (v[1] * 2 - 1)
- while True:
- if i < 0 or i >= self.height or j < 0 or j >= self.width:
- break
- o = 0
- if i > 0:
- o += (result[n, 0, i - 1, j] == self.token_conductor).long()
- if i < self.height - 1:
- o += (result[n, 0, i + 1, j] == self.token_conductor).long()
- if j > 0:
- o += (result[n, 0, i, j - 1] == self.token_conductor).long()
- if j < self.width - 1:
- o += (result[n, 0, i, j + 1] == self.token_conductor).long()
- if o > 1:
- break
- result[n, 0, i, j] = self.token_conductor
- i += vi
- j += vj
- if (
- result[n, 0] == self.token_conductor
- ).long().sum() > self.width and torch.rand(1) < 0.5:
- break
-
- while True:
- for _ in range(self.height * self.width):
- i = torch.randint(self.height, (1,))
- j = torch.randint(self.width, (1,))
- v = torch.randint(2, (2,))
- vi = v[0] * (v[1] * 2 - 1)
- vj = (1 - v[0]) * (v[1] * 2 - 1)
- if (
- i + vi >= 0
- and i + vi < self.height
- and j + vj >= 0
- and j + vj < self.width
- and result[n, 0, i, j] == self.token_conductor
- and result[n, 0, i + vi, j + vj] == self.token_conductor
- ):
- result[n, 0, i, j] = self.token_head
- result[n, 0, i + vi, j + vj] = self.token_tail
- break
-
- # if torch.rand(1) < 0.75:
- break
-
- weight = torch.full((1, 1, 3, 3), 1.0)
-
- mask = (torch.rand(result[:, 0].size()) < 0.01).long()
- rand = torch.randint(4, mask.size())
- result[:, 0] = mask * rand + (1 - mask) * result[:, 0]
-
- # empty->empty
- # head->tail
- # tail->conductor
- # conductor->head if 1 or 2 head in the neighborhood, or remains conductor
-
- nb_heads = (result[:, 0] == self.token_head).flatten(1).long().sum(dim=1)
- valid = nb_heads > 0
-
- for l in range(nb_frames - 1):
- nb_head_neighbors = (
- F.conv2d(
- input=(result[:, l] == self.token_head).float()[:, None, :, :],
- weight=weight,
- padding=1,
- )
- .long()
- .squeeze(1)
- )
- mask_1_or_2_heads = (nb_head_neighbors == 1).long() + (
- nb_head_neighbors == 2
- ).long()
- result[:, l + 1] = (
- (result[:, l] == self.token_empty).long() * self.token_empty
- + (result[:, l] == self.token_head).long() * self.token_tail
- + (result[:, l] == self.token_tail).long() * self.token_conductor
- + (result[:, l] == self.token_conductor).long()
- * (
- mask_1_or_2_heads * self.token_head
- + (1 - mask_1_or_2_heads) * self.token_conductor
- )
- )
- pred_nb_heads = nb_heads
- nb_heads = (
- (result[:, l + 1] == self.token_head).flatten(1).long().sum(dim=1)
- )
- valid = torch.logical_and(valid, (nb_heads >= pred_nb_heads))
-
- result = result[valid]
-
- result = result[
- :, torch.arange(self.nb_iterations, device=result.device) * self.speed
- ]
-
- i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0
- result = result[i]
-
- # print(f"{result.size(0)=} {nb=}")
-
- if result.size(0) < nb:
- # print(result.size(0))
- result = torch.cat(
- [result, self.generate_frame_sequences(nb - result.size(0))], dim=0
- )
-
- return result[:nb]
-
- def generate_token_sequences(self, nb):
- frame_sequences = self.generate_frame_sequences(nb)
-
- result = []
-
- for frame_sequence in frame_sequences:
- a = []
- if torch.rand(1) < 0.5:
- for frame in frame_sequence:
- if len(a) > 0:
- a.append(torch.tensor([self.token_forward]))
- a.append(frame.flatten())
- else:
- for frame in reversed(frame_sequence):
- if len(a) > 0:
- a.append(torch.tensor([self.token_backward]))
- a.append(frame.flatten())
-
- result.append(torch.cat(a, dim=0)[None, :])
-
- return torch.cat(result, dim=0)
-
- ######################################################################
-
- def frame2img(self, x, scale=15):
- x = x.reshape(-1, self.height, self.width)
- m = torch.logical_and(x >= 0, x < 4).long()
-
- x = self.colors[x * m].permute(0, 3, 1, 2)
- s = x.shape
- x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
- x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
-
- x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
- x[:, :, torch.arange(0, x.size(2), scale), :] = 0
- x = x[:, :, 1:, 1:]
-
- for n in range(m.size(0)):
- for i in range(m.size(1)):
- for j in range(m.size(2)):
- if m[n, i, j] == 0:
- for k in range(2, scale - 2):
- for l in [0, 1]:
- x[n, :, i * scale + k, j * scale + k - l] = 0
- x[
- n, :, i * scale + scale - 1 - k, j * scale + k - l
- ] = 0
-
- return x
-
- def seq2img(self, seq, scale=15):
- all = [
- self.frame2img(
- seq[:, : self.height * self.width].reshape(-1, self.height, self.width),
- scale,
- )
- ]
-
- separator = torch.full((seq.size(0), 3, self.height * scale - 1, 1), 0)
-
- t = self.height * self.width
-
- while t < seq.size(1):
- direction_tokens = seq[:, t]
- t += 1
-
- direction_images = self.colors[
- torch.full(
- (direction_tokens.size(0), self.height * scale - 1, scale), 0
- )
- ].permute(0, 3, 1, 2)
-
- for n in range(direction_tokens.size(0)):
- if direction_tokens[n] == self.token_forward:
- for k in range(scale):
- for l in [0, 1]:
- direction_images[
- n,
- :,
- (self.height * scale) // 2 - scale // 2 + k - l,
- 3 + scale // 2 - abs(k - scale // 2),
- ] = 0
- elif direction_tokens[n] == self.token_backward:
- for k in range(scale):
- for l in [0, 1]:
- direction_images[
- n,
- :,
- (self.height * scale) // 2 - scale // 2 + k - l,
- 3 + abs(k - scale // 2),
- ] = 0
- else:
- for k in range(2, scale - 2):
- for l in [0, 1]:
- direction_images[
- n,
- :,
- (self.height * scale) // 2 - scale // 2 + k - l,
- k,
- ] = 0
- direction_images[
- n,
- :,
- (self.height * scale) // 2 - scale // 2 + k - l,
- scale - 1 - k,
- ] = 0
-
- all += [
- separator,
- direction_images,
- separator,
- self.frame2img(
- seq[:, t : t + self.height * self.width].reshape(
- -1, self.height, self.width
- ),
- scale,
- ),
- ]
-
- t += self.height * self.width
-
- return torch.cat(all, dim=3)
-
- def seq2str(self, seq):
- result = []
- for s in seq:
- result.append("".join([self.token2char[v] for v in s]))
- return result
-
- def save_image(self, input, result_dir, filename):
- img = self.seq2img(input.to("cpu"))
- image_name = os.path.join(result_dir, filename)
- torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
-
- def save_quizzes(self, input, result_dir, filename_prefix):
- self.save_image(input, result_dir, filename_prefix + ".png")
-
-
-######################################################################
-
-if __name__ == "__main__":
- import time
-
- wireworld = Wireworld(height=8, width=10, nb_iterations=5, speed=1)
-
- start_time = time.perf_counter()
- frame_sequences = wireworld.generate_frame_sequences(nb=96)
- delay = time.perf_counter() - start_time
- print(f"{frame_sequences.size(0)/delay:02f} seq/s")
-
- # print(wireworld.seq2str(seq[:4]))
-
- for t in range(frame_sequences.size(1)):
- img = wireworld.seq2img(frame_sequences[:, t])
- torchvision.utils.save_image(
- img.float() / 255.0,
- f"/tmp/frame_{t:03d}.png",
- nrow=8,
- padding=6,
- pad_value=0,
- )
-
- # m = (torch.rand(seq.size()) < 0.05).long()
- # seq = (1 - m) * seq + m * 23
-
- wireworld = Wireworld(height=8, width=10, nb_iterations=2, speed=5)
- token_sequences = wireworld.generate_token_sequences(32)
- wireworld.save_quizzes(token_sequences, "/tmp", "seq")
- # img = wireworld.seq2img(frame_sequences[:60])
-
- # torchvision.utils.save_image(
- # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=10, pad_value=0.1
- # )