--- /dev/null
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import math, sys, tqdm, os, warnings
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+######################################################################
+
+import problem
+
+
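+# Each quiz shows an example grid pair (A, f_A) and a new 10x10 grid B; the
+# expected answer is f_B, i.e. B with the same transformation applied. The only
+# transformation implemented here recolors a red rectangle to green.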
+class Lang(problem.Problem):
+ named_colors = [
+ ("white", [255, 255, 255]),
+ ("red", [255, 0, 0]),
+ ("green", [0, 192, 0]),
+ ("blue", [0, 0, 255]),
+ ("orange", [255, 192, 0]),
+ ("cyan", [0, 255, 255]),
+ ("violet", [255, 0, 255]),
+ ("lightgreen", [192, 255, 192]),
+ ("pink", [255, 192, 192]),
+ ("lightblue", [192, 192, 255]),
+ ("gray", [192, 192, 192]),
+ ]
+
+ def __init__(
+ self,
+ nb_iterations=2,
+ ):
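+        # RGB value of every token, and a color-name -> token-index lookup;
+        # nb_iterations is stored but not used by the task defined in this file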
+ self.colors = torch.tensor([c for _, c in self.named_colors])
+        self.name2color = {name: i for i, (name, _) in enumerate(self.named_colors)}
+ self.height = 10
+ self.width = 10
+ self.nb_iterations = nb_iterations
+
+ ######################################################################
+
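+    # Convert a batch of token grids (nb, height, width) into RGB images,
+    # scaling every cell to scale x scale pixels and drawing black grid lines
+    # between cells.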
+ def frame2img(self, x, scale=15):
+ x = x.reshape(x.size(0), self.height, -1)
+ x = self.colors[x].permute(0, 3, 1, 2)
+ s = x.shape
+ x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
+ x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
+
+ x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
+ x[:, :, torch.arange(0, x.size(2), scale), :] = 0
+ x = x[:, :, 1:, 1:]
+
+ return x
+
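+    # Render prompts and answers side by side, separated by an arrow marker, and
+    # save the whole batch as a single PNG. predicted_prompts / predicted_answers
+    # are optional per-sample 0/1 flags rendered as a black (1) or white (0) band
+    # below each grid.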
+ def save_image(
+ self,
+ result_dir,
+ filename,
+ prompts,
+ answers,
+ predicted_prompts=None,
+ predicted_answers=None,
+ ):
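+        # default to plain white (255) bands when no prediction flags are given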
+ if predicted_prompts is None:
+ predicted_prompts = 255
+
+ if predicted_answers is None:
+ predicted_answers = 255
+
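+        # Pad x with a frame of color c: an int grey level, or a per-sample 0/1
+        # tensor mapped to black (1) / white (0). With bottom=True only a bottom
+        # band of height margin is added.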
+ def add_frame(x, c, margin, bottom=False):
+ if bottom:
+ h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
+ else:
+ h, w, di, dj = (
+ x.size(2) + 2 * margin,
+ x.size(3) + 2 * margin,
+ margin,
+ margin,
+ )
+
+ y = x.new_full((x.size(0), x.size(1), h, w), 0)
+
+ if type(c) is int:
+ y[...] = c
+ else:
+ c = c.long()[:, None]
+ c = c * torch.tensor([0, 0, 0], device=c.device) + (
+ 1 - c
+ ) * torch.tensor([255, 255, 255], device=c.device)
+ y[...] = c[:, :, None, None]
+
+ y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
+
+ return y
+
+ margin = 4
+
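+        # thin black border around each grid, then a white band below it, then a
+        # band whose color encodes the sample's predicted_* flag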
+ img_prompts = add_frame(self.frame2img(prompts.to("cpu")), c=0, margin=1)
+ h = img_prompts.size(2)
+ img_answers = add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1)
+
+ img_prompts = add_frame(img_prompts, c=255, margin=margin, bottom=True)
+ img_answers = add_frame(img_answers, c=255, margin=margin, bottom=True)
+
+ img_prompts = add_frame(
+ img_prompts, c=predicted_prompts, margin=margin, bottom=True
+ )
+ img_answers = add_frame(
+ img_answers, c=predicted_answers, margin=margin, bottom=True
+ )
+
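+        # white column separating the prompt block from the answer block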
+ marker_size = 16
+
+ separator = img_prompts.new_full(
+ (
+ img_prompts.size(0),
+ img_prompts.size(1),
+ img_prompts.size(2),
+ marker_size,
+ ),
+ 255,
+ )
+
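+        # black lines level with the top and bottom borders of the grids, plus a
+        # ">" marker pointing from the prompt towards the answer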
+ separator[:, :, 0] = 0
+ separator[:, :, h - 1] = 0
+
+ for k in range(1, 2 * marker_size - 8):
+ i = k - (marker_size - 4)
+ j = marker_size - 5 - abs(i)
+ separator[:, :, h // 2 - 1 + i, 2 + j] = 0
+ separator[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
+
+ img = torch.cat([img_prompts, separator, img_answers], dim=3)
+
+ image_name = os.path.join(result_dir, filename)
+ torchvision.utils.save_image(
+ img.float() / 255.0, image_name, nrow=4, padding=margin * 4, pad_value=1.0
+ )
+
+ ######################################################################
+
+ def nb_token_values(self):
+ return len(self.colors)
+
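+    # Sample the corners (i1, j1, i2, j2) of a random rectangle, at least 2 cells
+    # high and 2 cells wide, that fits inside a grid shaped like x.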
+ def rec_coo(self, x):
+ while True:
+ i1, i2 = torch.randint(x.size(0), (2,))
+ if i1 < i2 - 1:
+ break
+ while True:
+ j1, j2 = torch.randint(x.size(1), (2,))
+ if j1 < j2 - 1:
+ break
+ return i1, j1, i2, j2
+
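+    # The single task: fill a random rectangle of A with red and the same
+    # rectangle of f_A with green, then do the same with an independently drawn
+    # rectangle for B and f_B.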
+ def task_red_to_green(self, A, f_A, B, f_B):
+ i1, j1, i2, j2 = self.rec_coo(A)
+ A[i1:i2, j1:j2] = self.name2color["red"]
+ f_A[i1:i2, j1:j2] = self.name2color["green"]
+ i1, j1, i2, j2 = self.rec_coo(B)
+ B[i1:i2, j1:j2] = self.name2color["red"]
+ f_B[i1:i2, j1:j2] = self.name2color["green"]
+
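+    # Each prompt is the horizontal concatenation [A | f_A | B]; the answer is f_B.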
+ def generate_prompts_and_answers(self, nb):
+ prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
+ answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
+ w = self.width
+ for prompt, answer in zip(prompts, answers):
+ self.task_red_to_green(
+ prompt[:, 0 * w : 1 * w],
+ prompt[:, 1 * w : 2 * w],
+ prompt[:, 2 * w : 3 * w],
+ answer,
+ )
+ return prompts, answers
+
+ def save_quizzes(
+ self,
+ result_dir,
+ filename_prefix,
+ prompts,
+ answers,
+ predicted_prompts=None,
+ predicted_answers=None,
+ ):
+ self.save_image(
+ result_dir,
+ filename_prefix + ".png",
+ prompts,
+ answers,
+ predicted_prompts,
+ predicted_answers,
+ )
+
+
+######################################################################
+
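+# Quick sanity check: generate 24 quizzes and save them as /tmp/test.png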
+if __name__ == "__main__":
+
+ lang = Lang(nb_iterations=4)
+
+ prompts, answers = lang.generate_prompts_and_answers(24)
+
+ # predicted_prompts = torch.rand(prompts.size(0)) < 0.5
+ # predicted_answers = torch.rand(answers.size(0)) < 0.5
+
+ lang.save_quizzes(
+ "/tmp", "test", prompts, answers # , predicted_prompts, predicted_answers
+ )