--- /dev/null
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import math, sys, tqdm, os, warnings, cairo, re
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+######################################################################
+
+
+def text_img(height, width, text):
+ pixel_map = torch.full((height, width, 4), 255, dtype=torch.uint8)
+
+ surface = cairo.ImageSurface.create_for_data(
+ pixel_map.numpy(), cairo.FORMAT_ARGB32, pixel_map.size(1), pixel_map.size(0)
+ )
+
+ ctx = cairo.Context(surface)
+ ctx.set_source_rgb(0, 0, 0)
+ ctx.set_font_size(16)
+ ctx.select_font_face("courier", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL)
+ y = None
+ for line in text.split("\n"):
+ xbearing, ybearing, width, height, dx, dy = ctx.text_extents(line)
+ if y is None:
+ y = height * 1.5
+ x = height * 0.5
+
+ ctx.move_to(x, y)
+ ctx.show_text(line)
+ y += height * 1.5
+
+ ctx.stroke()
+
+ return pixel_map.permute(2, 0, 1)[None, :3].contiguous()
+
+
+######################################################################
+
+import problem
+
+
+class Grids(problem.Problem):
+ grid_gray = 240
+ thickness = 1
+ background_gray = 240
+ dots = False
+
+ named_colors = [
+ ("white", [background_gray, background_gray, background_gray]),
+ # ("white", [224, 224, 224]),
+ ("red", [255, 0, 0]),
+ ("green", [0, 160, 0]),
+ ("blue", [0, 0, 255]),
+ ("yellow", [255, 224, 0]),
+ ("cyan", [0, 255, 255]),
+ ("violet", [224, 128, 255]),
+ ("lightgreen", [160, 255, 160]),
+ ("brown", [165, 42, 42]),
+ ("lightblue", [192, 192, 255]),
+ ("gray", [128, 128, 128]),
+ ]
+
+ def pure_noise(self, nb, device):
+ result = torch.randint(
+ self.nb_colors, (nb, 4 * (self.height * self.height)), device=device
+ )
+ return result
+
+ def trivial(self, quizzes):
+ S = self.height * self.width
+ assert self.check_order(quizzes, quad_order=("A", "f_A", "B", "f_B"))
+ a = quizzes.reshape(quizzes.size(0), 4, S + 1)[:, :, 1:]
+ return (a[:, 0] == a[:, 1]).min(dim=1).values | (a[:, 2] == a[:, 3]).min(
+ dim=1
+ ).values
+
+ def text2quiz(self, t):
+ chr2col = [
+ (".", "white"),
+ ("r", "red"),
+ ("g", "green"),
+ ("b", "blue"),
+ ("y", "yellow"),
+ ("c", "cyan"),
+ ("v", "violet"),
+ ("l", "lightgreen"),
+ ("o", "brown"),
+ ("l", "lightblue"),
+ ("a", "gray"),
+ ]
+
+ col2tok = dict([(c[0], n) for n, c in enumerate(self.named_colors)])
+ chr2tok = dict([(c, col2tok[col]) for c, col in chr2col])
+
+ t = re.sub(r"#.*\n", "", t).strip()
+ l = t.replace("\n\n", ";").split(";")
+
+ result = []
+
+ for t in l:
+ t = "".join(t.replace("\n", " ").strip().split(" "))
+ t = torch.tensor([chr2tok[c] for c in t])
+ t = t.reshape(10, 4, 10).permute(1, 0, 2).flatten(1)
+ t = torch.cat(
+ [
+ torch.tensor(
+ [
+ [self.token_A],
+ [self.token_f_A],
+ [self.token_B],
+ [self.token_f_B],
+ ]
+ ),
+ t,
+ ],
+ dim=1,
+ )
+ result.append(t.flatten()[None, :])
+
+ return torch.cat(result, dim=0)
+
+ def __init__(
+ self,
+ max_nb_cached_chunks=None,
+ chunk_size=None,
+ nb_threads=-1,
+ tasks=None,
+ ):
+ self.colors = torch.tensor([c for _, c in self.named_colors])
+
+ self.nb_colors = len(self.colors)
+
+ self.nb_rec_max = 5
+ self.rfree = torch.tensor([])
+
+ self.height = 12
+ self.width = 20
+ self.seq_len = 4 * self.height * self.width
+
+ self.cache_rec_coo = {}
+
+ all_tasks = [
+ self.task_replace_color,
+ self.task_translate,
+ self.task_grow,
+ self.task_frame,
+ ]
+
+ if tasks is None:
+ self.all_tasks = all_tasks
+ else:
+ self.all_tasks = [getattr(self, "task_" + t) for t in tasks.split(",")]
+
+ super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
+
+ ######################################################################
+
+ def vocabulary_size(self):
+ return self.nb_colors
+
+ def grid2img(self, x, scale=15, grids=True):
+ m = torch.logical_and(x >= 0, x < self.nb_colors).long()
+ y = self.colors[x * m].permute(0, 3, 1, 2)
+ s = y.shape
+ y = y[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
+ y = y.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
+
+ if grids:
+ for t in range(self.thickness):
+ y[:, :, :, torch.arange(t, y.size(3), scale)] = self.grid_gray
+ y[:, :, torch.arange(t, y.size(2), scale), :] = self.grid_gray
+ if self.dots:
+ z = y.reshape(
+ y.size(0),
+ y.size(1),
+ y.size(2) // scale,
+ scale,
+ y.size(3) // scale,
+ scale,
+ )
+ z = z[
+ :,
+ :,
+ :,
+ scale // 2 - 1 : scale // 2 + 2,
+ :,
+ scale // 2 - 1 : scale // 2 + 2,
+ ]
+ zz = (z == self.background_gray).min(dim=1, keepdim=True).values
+ z[...] = zz * self.grid_gray + (zz == False) * z
+
+ for n in range(m.size(0)):
+ for i in range(m.size(1)):
+ for j in range(m.size(2)):
+ if x[n, i, j] >= self.nb_colors:
+ # for k in range(3, scale - 2):
+ c = self.colors[x[n, i, j] - self.nb_colors][:, None, None]
+ # y[n, :, i * scale + k, j * scale + k] = c
+ # y[n, :, i * scale + k, j * scale + scale - k] = c
+ y[
+ n,
+ :,
+ i * scale + 3 : i * scale + scale - 2,
+ j * scale + 3 : j * scale + scale - 2,
+ ] = c
+
+ y = y[:, :, 1:, 1:]
+
+ return y
+
+ def add_frame(self, img, colors, thickness):
+ if thickness > 0:
+ result = img.new(
+ img.size(0),
+ img.size(1),
+ img.size(2) + 2 * thickness,
+ img.size(3) + 2 * thickness,
+ )
+
+ result[...] = colors[:, :, None, None]
+ result[:, :, thickness:-thickness, thickness:-thickness] = img
+ else:
+ result = img
+
+ return result
+
+ def save_quizzes_as_image(
+ self,
+ result_dir,
+ filename,
+ quizzes,
+ predicted_parts=None,
+ correct_parts=None,
+ comments=None,
+ comment_height=48,
+ nrow=4,
+ grids=True,
+ margin=12,
+ delta=False,
+ delta_highlight=False,
+ ):
+ quizzes = quizzes.to("cpu")
+
+ S = self.height * self.width
+
+ A, f_A, B, f_B = (
+ quizzes.reshape(quizzes.size(0), 4, S)
+ .reshape(quizzes.size(0), 4, self.height, self.width)
+ .permute(1, 0, 2, 3)
+ )
+
+ frame, white, gray, green, red = torch.tensor(
+ [
+ [self.grid_gray, self.grid_gray, self.grid_gray],
+ [255, 255, 255],
+ [200, 200, 200],
+ [0, 255, 0],
+ [255, 0, 0],
+ ],
+ device=quizzes.device,
+ )
+
+ thickness = self.thickness
+
+ if delta:
+ u = (A != f_A).long()
+ img_delta_A = self.add_frame(
+ self.grid2img(u, grids=grids), frame[None, :], thickness=thickness
+ )
+ img_delta_A = img_delta_A.min(dim=1, keepdim=True).values.expand_as(
+ img_delta_A
+ )
+ u = (B != f_B).long()
+ img_delta_B = self.add_frame(
+ self.grid2img(u, grids=grids), frame[None, :], thickness=thickness
+ )
+ img_delta_B = img_delta_B.min(dim=1, keepdim=True).values.expand_as(
+ img_delta_B
+ )
+
+ img_A = self.add_frame(
+ self.grid2img(A, grids=grids), frame[None, :], thickness=thickness
+ )
+ img_f_A = self.add_frame(
+ self.grid2img(f_A, grids=grids), frame[None, :], thickness=thickness
+ )
+ img_B = self.add_frame(
+ self.grid2img(B, grids=grids), frame[None, :], thickness=thickness
+ )
+ img_f_B = self.add_frame(
+ self.grid2img(f_B, grids=grids), frame[None, :], thickness=thickness
+ )
+
+ if delta_highlight:
+ q = (img_B == img_f_B).min(dim=1, keepdim=True).values.long()
+ img_f_B = q * (img_f_B // 4 + 192) + (1 - q) * img_f_B
+
+ # predicted_parts Nx4
+ # correct_parts Nx4
+
+ if predicted_parts is None:
+ colors = white[None, None, :].expand(-1, 4, -1)
+ else:
+ predicted_parts = predicted_parts.to("cpu")
+ if correct_parts is None:
+ colors = (
+ predicted_parts[:, :, None] * gray[None, None, :]
+ + (1 - predicted_parts[:, :, None]) * white[None, None, :]
+ )
+ else:
+ correct_parts = correct_parts.to("cpu")
+ colors = (
+ predicted_parts[:, :, None]
+ * (
+ (correct_parts[:, :, None] == 1).long() * green[None, None, :]
+ + (correct_parts[:, :, None] == 0).long() * gray[None, None, :]
+ + (correct_parts[:, :, None] == -1).long() * red[None, None, :]
+ )
+ + (1 - predicted_parts[:, :, None]) * white[None, None, :]
+ )
+
+ separation = 6
+
+ img_A = self.add_frame(img_A, colors[:, 0], thickness=separation)
+ img_f_A = self.add_frame(img_f_A, colors[:, 1], thickness=separation)
+ img_B = self.add_frame(img_B, colors[:, 2], thickness=separation)
+ img_f_B = self.add_frame(img_f_B, colors[:, 3], thickness=separation)
+
+ img_A = self.add_frame(img_A, white[None, :], thickness=2)
+ img_f_A = self.add_frame(img_f_A, white[None, :], thickness=2)
+ img_B = self.add_frame(img_B, white[None, :], thickness=2)
+ img_f_B = self.add_frame(img_f_B, white[None, :], thickness=2)
+
+ if delta:
+ img_delta_A = self.add_frame(
+ img_delta_A, colors[:, 0], thickness=separation
+ )
+ img_delta_A = self.add_frame(img_delta_A, white[None, :], thickness=2)
+ img_delta_B = self.add_frame(
+ img_delta_B, colors[:, 0], thickness=separation
+ )
+ img_delta_B = self.add_frame(img_delta_B, white[None, :], thickness=2)
+ img = torch.cat(
+ [img_A, img_f_A, img_delta_A, img_B, img_f_B, img_delta_B], dim=3
+ )
+ else:
+ img = torch.cat([img_A, img_f_A, img_B, img_f_B], dim=3)
+
+ if comments is not None:
+ comment_img = [text_img(comment_height, img.size(3), t) for t in comments]
+ comment_img = torch.cat(comment_img, dim=0)
+ img = torch.cat([img, comment_img], dim=2)
+
+ image_name = os.path.join(result_dir, filename)
+
+ torchvision.utils.save_image(
+ img.float() / 255.0,
+ image_name,
+ nrow=nrow,
+ padding=margin * 4,
+ pad_value=1.0,
+ )
+
+ ######################################################################
+
+ # @torch.compile
+ def rec_coo(
+ self,
+ nb_rec,
+ min_height=3,
+ min_width=3,
+ surface_max=None,
+ ):
+ if surface_max is None:
+ surface_max = self.height * self.width // 4
+
+ signature = (nb_rec, min_height, min_width, surface_max)
+
+ try:
+ return self.cache_rec_coo[signature].pop()
+ except IndexError:
+ pass
+ except KeyError:
+ pass
+
+ N = 10000
+ while True:
+ while True:
+ i = torch.randint(self.height, (N * nb_rec, 2)).sort(dim=-1).values
+ j = torch.randint(self.width, (N * nb_rec, 2)).sort(dim=-1).values
+ i[:, 1] += 1
+ j[:, 1] += 1
+ big_enough = (
+ (i[:, 1] >= i[:, 0] + min_height)
+ & (j[:, 1] >= j[:, 0] + min_height)
+ & ((i[:, 1] - i[:, 0]) * (j[:, 1] - j[:, 0]) <= surface_max)
+ )
+
+ i, j = i[big_enough], j[big_enough]
+
+ n = i.size(0) - i.size(0) % nb_rec
+
+ if n > 0:
+ break
+
+ i = i[:n].reshape(n // nb_rec, nb_rec, -1)
+ j = j[:n].reshape(n // nb_rec, nb_rec, -1)
+
+ if i.size(0) > 1:
+ break
+
+ self.cache_rec_coo[signature] = [
+ [
+ (
+ i[n, k, 0].item(),
+ j[n, k, 0].item(),
+ i[n, k, 1].item(),
+ j[n, k, 1].item(),
+ )
+ for k in range(nb_rec)
+ ]
+ for n in range(i.size(0))
+ ]
+
+ return self.cache_rec_coo[signature].pop()
+
+ ######################################################################
+
+ def task_replace_color(self, A, f_A, B, f_B):
+ nb_rec = 3
+ c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ r = self.rec_coo(nb_rec)
+ for n in range(nb_rec):
+ i1, j1, i2, j2 = r[n]
+ X[i1:i2, j1:j2] = c[n]
+ f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
+
+ def task_translate(self, A, f_A, B, f_B):
+ while True:
+ di, dj = torch.randint(3, (2,)) - 1
+ if di.abs() + dj.abs() > 0:
+ break
+
+ nb_rec = 3
+ c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ while True:
+ r = self.rec_coo(nb_rec)
+ i1, j1, i2, j2 = r[nb_rec - 1]
+ if (
+ i1 + di >= 0
+ and i2 + di < X.size(0)
+ and j1 + dj >= 0
+ and j2 + dj < X.size(1)
+ ):
+ break
+
+ for n in range(nb_rec):
+ i1, j1, i2, j2 = r[n]
+ X[i1:i2, j1:j2] = c[n]
+ if n == nb_rec - 1:
+ f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c[n]
+ else:
+ f_X[i1:i2, j1:j2] = c[n]
+
+ def task_grow(self, A, f_A, B, f_B):
+ di, dj = torch.randint(2, (2,)) * 2 - 1
+ nb_rec = 3
+ c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
+ direction = torch.randint(2, (1,)).item()
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ while True:
+ r = self.rec_coo(nb_rec)
+ i1, j1, i2, j2 = r[nb_rec - 1]
+ if i1 + 3 < i2 and j1 + 3 < j2:
+ break
+
+ for n in range(nb_rec):
+ i1, j1, i2, j2 = r[n]
+ if n == nb_rec - 1:
+ if direction == 0:
+ X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
+ f_X[i1:i2, j1:j2] = c[n]
+ else:
+ X[i1:i2, j1:j2] = c[n]
+ f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
+ else:
+ X[i1:i2, j1:j2] = c[n]
+ f_X[i1:i2, j1:j2] = c[n]
+
+ # @torch.compile
+ def task_frame(self, A, f_A, B, f_B):
+ nb_rec = 3
+ c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ r = self.rec_coo(nb_rec)
+ for n in range(nb_rec):
+ i1, j1, i2, j2 = r[n]
+ X[i1:i2, j1:j2] = c[n]
+ if n == nb_rec - 1:
+ f_X[i1:i2, j1] = c[n]
+ f_X[i1:i2, j2 - 1] = c[n]
+ f_X[i1, j1:j2] = c[n]
+ f_X[i2 - 1, j1:j2] = c[n]
+ else:
+ f_X[i1:i2, j1:j2] = c[n]
+
+ ######################################################################
+
+ def create_empty_quizzes(self, nb, quad_order=("A", "f_A", "B", "f_B")):
+ S = self.height * self.width
+ quizzes = torch.zeros(nb, 4 * (S + 1), dtype=torch.int64)
+ quizzes[:, 0 * (S + 1)] = self.l2tok[quad_order[0]]
+ quizzes[:, 1 * (S + 1)] = self.l2tok[quad_order[1]]
+ quizzes[:, 2 * (S + 1)] = self.l2tok[quad_order[2]]
+ quizzes[:, 3 * (S + 1)] = self.l2tok[quad_order[3]]
+
+ return quizzes
+
+ def generate_w_quizzes_(self, nb, tasks=None, progress_bar=False):
+ S = self.height * self.width
+
+ if tasks is None:
+ tasks = self.all_tasks
+
+ quizzes = torch.empty(nb, 4 * self.height * self.width, dtype=torch.int64)
+
+ if progress_bar:
+ quizzes = tqdm.tqdm(
+ quizzes,
+ dynamic_ncols=True,
+ desc="world quizzes generation",
+ total=quizzes.size(0),
+ )
+
+ for quiz in quizzes:
+ q = quiz.reshape(4, self.height, self.width)
+ q[...] = 0
+ A, f_A, B, f_B = q
+ task = tasks[torch.randint(len(tasks), (1,)).item()]
+ task(A, f_A, B, f_B)
+
+ return quizzes
+
+ def save_some_examples(self, result_dir, prefix=""):
+ nb, nrow = 256, 8
+ for t in self.all_tasks:
+ print(t.__name__)
+ quizzes = self.generate_w_quizzes_(nb, tasks=[t])
+ self.save_quizzes_as_image(
+ result_dir, prefix + t.__name__ + ".png", quizzes, nrow=nrow, delta=True
+ )
+
+ def detect_rectangles(self, q1, q2):
+ c = torch.arange(self.nb_colors)
+ I = torch.arange(self.height)[None, :, None]
+ J = torch.arange(self.width)[None, :, None]
+
+ def corners(q):
+ q = q.reshape(-1, self.height, self.width)
+ a = (q[:, :, :, None] == c[None, None, None, :]).long()
+ mi = a.max(dim=2).values
+ i = mi * I
+ i1 = (i + (1 - mi) * q.size(1)).min(dim=1).values
+ i2 = (i + (1 - mi) * (-1)).max(dim=1).values + 1
+ mj = a.max(dim=1).values
+ j = mj * J
+ j1 = (j + (1 - mj) * q.size(2)).min(dim=1).values
+ j2 = (j + (1 - mj) * (-1)).max(dim=1).values + 1
+ m = (
+ ((I > i1[:, None, :]) & (I < i2[:, None, :] - 1))[:, :, None, :]
+ & ((J > j1[:, None, :]) & (J < j2[:, None, :] - 1))[:, None, :, :]
+ ).long()
+ f = ((a * m).long().sum(dim=(1, 2)) > 0).long()
+ return i1, i2, j1, j2, f
+
+ q1_i1, q1_i2, q1_j1, q1_j2, q1_f = corners(q1)
+ q2_i1, q2_i2, q2_j1, q2_j2, q2_f = corners(q2)
+ u1, u2 = 0, 0
+
+ for _ in range(10):
+ r1 = q.new_zeros(q1.size(0), self.height, self.width)
+ r2 = q.new_zeros(q1.size(0), self.height, self.width)
+
+ m1 = (
+ ((I >= q1_i1[:, None, :]) & (I < q1_i2[:, None, :]))[:, :, None, :]
+ & ((J >= q1_j1[:, None, :]) & (J < q1_j2[:, None, :]))[:, None, :, :]
+ ).long()
+
+ f1 = (
+ (
+ ((I == q1_i1[:, None, :]) | (I == q1_i2[:, None, :] - 1))[
+ :, :, None, :
+ ]
+ & ((J >= q1_j1[:, None, :]) & (J < q1_j2[:, None, :]))[
+ :, None, :, :
+ ]
+ )
+ | (
+ ((I >= q1_i1[:, None, :]) & (I < q1_i2[:, None, :] - 1))[
+ :, :, None, :
+ ]
+ & ((J == q1_j1[:, None, :]) | (J == q1_j2[:, None, :] - 1))[
+ :, None, :, :
+ ]
+ )
+ ).long()
+
+ r2 = q.new_zeros(q2.size(0), self.height, self.width)
+
+ m2 = (
+ ((I >= q2_i1[:, None, :]) & (I < q2_i2[:, None, :]))[:, :, None, :]
+ & ((J >= q2_j1[:, None, :]) & (J < q2_j2[:, None, :]))[:, None, :, :]
+ ).long()
+
+ f2 = (
+ (
+ ((I == q2_i1[:, None, :]) | (I == q2_i2[:, None, :] - 1))[
+ :, :, None, :
+ ]
+ & ((J >= q2_j1[:, None, :]) & (J < q2_j2[:, None, :]))[
+ :, None, :, :
+ ]
+ )
+ | (
+ ((I >= q2_i1[:, None, :]) & (I < q2_i2[:, None, :] - 1))[
+ :, :, None, :
+ ]
+ & ((J == q2_j1[:, None, :]) | (J == q2_j2[:, None, :] - 1))[
+ :, None, :, :
+ ]
+ )
+ ).long()
+
+ for c in torch.randperm(self.nb_colors - 1) + 1:
+ r1[...] = q1_f[:, None, None, c] * (
+ m1[:, :, :, c] * c + (1 - m1[:, :, :, c]) * r1
+ ) + (1 - q1_f[:, None, None, c]) * (
+ f1[:, :, :, c] * c + (1 - f1[:, :, :, c]) * r1
+ )
+
+ r2[...] = q2_f[:, None, None, c] * (
+ m2[:, :, :, c] * c + (1 - m2[:, :, :, c]) * r2
+ ) + (1 - q2_f[:, None, None, c]) * (
+ f2[:, :, :, c] * c + (1 - f2[:, :, :, c]) * r2
+ )
+
+ match = (
+ (q1 == r1.flatten(1)).min(dim=1).values
+ & (q2 == r2.flatten(1)).min(dim=1).values
+ ).long()[:, None, None]
+ u1 = (1 - match) * u1 + match * r1
+ u2 = (1 - match) * u2 + match * r2
+
+ return u1.flatten(1), u2.flatten(1)
+
+ # o = F.one_hot(q * (1 - m)).sum(dim=1)
+ # print(o)
+ # print(o.sort(dim=1, descending=True))
+ # c = N x nb_col x 4
+
+
+######################################################################
+
+if __name__ == "__main__":
+ import time
+
+ grids = Grids()
+
+ nb, nrow = 64, 4
+ nb_rows = 12
+
+ # c_quizzes = torch.load("/home/fleuret/state.pth")["train_c_quizzes"]
+ # c_quizzes = c_quizzes[torch.randperm(c_quizzes.size(0))[: nrow * nb_rows]]
+
+ # grids.save_quizzes_as_image(
+ # "/tmp",
+ # "c_quizzes.png",
+ # c_quizzes,
+ # delta=True,
+ # nrow=nrow,
+ # margin=10,
+ # grids=False
+ # comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))],
+ # )
+
+ w_quizzes = grids.generate_w_quizzes_(
+ 16,
+ tasks=[
+ grids.task_replace_color,
+ grids.task_translate,
+ grids.task_grow,
+ grids.task_frame,
+ ],
+ )
+
+ q = w_quizzes.reshape(-1, 4, w_quizzes.size(1) // 4)
+ r = q.new_zeros(q.size())
+ r[:, 0], r[:, 1] = grids.detect_rectangles(q[:, 0], q[:, 1])
+ r[:, 2], r[:, 3] = grids.detect_rectangles(q[:, 2], q[:, 3])
+
+ grids.save_quizzes_as_image(
+ "/tmp",
+ "q.png",
+ q.flatten(1),
+ # delta=True,
+ nrow=nrow,
+ margin=10,
+ # grids=False
+ # comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))],
+ )
+
+ grids.save_quizzes_as_image(
+ "/tmp",
+ "r.png",
+ r.flatten(1),
+ # delta=True,
+ nrow=nrow,
+ margin=10,
+ # grids=False
+ # comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))],
+ )
+
+ exit(0)
+
+ q = grids.text2quiz(
+ """
+
+# the original
+
+vvvvaaaaa. rrrraaaaa. .......... ..........
+vvvvaaaaa. rrrraaaaa. ...aaa.... ...aaa....
+vvvvaaaaa. rrrraaaaa. ...aaa.... ...aaa....
+vvvvaaaaa. rrrraaaaa. ...aaa.... ...aaa....
+....aaaaa. ....aaaaa. .vvvvv.... .rrrrr....
+.......... .......... .vvvvvvvvv .rrrrroooo
+.......... .......... .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+
+vvvvaaaaa. rrrraaaaa. .......... ..........
+vvvvaaaaa. rrrraaaaa. .......... ..........
+vvvvaaaaa. rrrraaaaa. .......aaa .......aaa
+vvvvaaaaa. rrrraaaaa. .......aaa .......aaa
+....aaaaa. ....aaaaa. .vvvvv.aaa .rrrrr.aaa
+.......... .......... .vvvvvvvvv .rrrrroooo
+.......... .......... .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+
+#
+# so what
+#
+
+vvvv...... rrrr...... .......... ..........
+vvvv...... rrrr...... .......... ..........
+vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa
+vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa
+.....aaaaa .....aaaaa .vvvvv.aaa .rrrrr.aaa
+.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo
+.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+
+vvvv...... rrrr...... .......... ..........
+vvvv...... rrrr...... .......... ..........
+vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa
+vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa
+.....aaaaa .....aaaaa .vvvvv.aaa .rrrrr.aaa
+.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo
+.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+"""
+ )
+
+ grids.save_quizzes_as_image("/tmp", "test.png", q, nrow=1, grids=False)