d53386c5761e917ebd9592e79875c50bc698c811
[culture.git] / lang.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, tqdm, os, warnings
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
17 import problem
18
19
20 class Lang(problem.Problem):
21     named_colors = [
22         ("white", [255, 255, 255]),
23         ("red", [255, 0, 0]),
24         ("green", [0, 192, 0]),
25         ("blue", [0, 0, 255]),
26         ("orange", [255, 192, 0]),
27         ("cyan", [0, 255, 255]),
28         ("violet", [255, 0, 255]),
29         ("lightgreen", [192, 255, 192]),
30         ("pink", [255, 192, 192]),
31         ("lightblue", [192, 192, 255]),
32         ("gray", [192, 192, 192]),
33     ]
34
35     def __init__(
36         self,
37         nb_iterations=2,
38     ):
39         self.colors = torch.tensor([c for _, c in self.named_colors])
40         self.name2color = dict([(p[0], i) for i, p in enumerate(self.named_colors)])
41         self.height = 10
42         self.width = 10
43         self.nb_iterations = nb_iterations
44
45     ######################################################################
46
47     def frame2img(self, x, scale=15):
48         x = x.reshape(x.size(0), self.height, -1)
49         x = self.colors[x].permute(0, 3, 1, 2)
50         s = x.shape
51         x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
52         x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
53
54         x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
55         x[:, :, torch.arange(0, x.size(2), scale), :] = 0
56         x = x[:, :, 1:, 1:]
57
58         return x
59
60     def save_image(
61         self,
62         result_dir,
63         filename,
64         prompts,
65         answers,
66         predicted_prompts=None,
67         predicted_answers=None,
68     ):
69         if predicted_prompts is None:
70             predicted_prompts = 255
71
72         if predicted_answers is None:
73             predicted_answers = 255
74
75         def add_frame(x, c, margin, bottom=False):
76             if bottom:
77                 h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
78             else:
79                 h, w, di, dj = (
80                     x.size(2) + 2 * margin,
81                     x.size(3) + 2 * margin,
82                     margin,
83                     margin,
84                 )
85
86             y = x.new_full((x.size(0), x.size(1), h, w), 0)
87
88             if type(c) is int:
89                 y[...] = c
90             else:
91                 c = c.long()[:, None]
92                 c = c * torch.tensor([0, 0, 0], device=c.device) + (
93                     1 - c
94                 ) * torch.tensor([255, 255, 255], device=c.device)
95                 y[...] = c[:, :, None, None]
96
97             y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
98
99             return y
100
101         margin = 4
102
103         img_prompts = add_frame(self.frame2img(prompts.to("cpu")), c=0, margin=1)
104         h = img_prompts.size(2)
105         img_answers = add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1)
106
107         img_prompts = add_frame(img_prompts, c=255, margin=margin, bottom=True)
108         img_answers = add_frame(img_answers, c=255, margin=margin, bottom=True)
109
110         img_prompts = add_frame(
111             img_prompts, c=predicted_prompts, margin=margin, bottom=True
112         )
113         img_answers = add_frame(
114             img_answers, c=predicted_answers, margin=margin, bottom=True
115         )
116
117         marker_size = 16
118
119         separator = img_prompts.new_full(
120             (
121                 img_prompts.size(0),
122                 img_prompts.size(1),
123                 img_prompts.size(2),
124                 marker_size,
125             ),
126             255,
127         )
128
129         separator[:, :, 0] = 0
130         separator[:, :, h - 1] = 0
131
132         for k in range(1, 2 * marker_size - 8):
133             i = k - (marker_size - 4)
134             j = marker_size - 5 - abs(i)
135             separator[:, :, h // 2 - 1 + i, 2 + j] = 0
136             separator[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
137
138         img = torch.cat([img_prompts, separator, img_answers], dim=3)
139
140         image_name = os.path.join(result_dir, filename)
141         torchvision.utils.save_image(
142             img.float() / 255.0, image_name, nrow=4, padding=margin * 4, pad_value=1.0
143         )
144
145     ######################################################################
146
147     def nb_token_values(self):
148         return len(self.colors)
149
150     def rec_coo(self, x):
151         while True:
152             i1, i2 = torch.randint(x.size(0), (2,))
153             if i1 < i2 - 1:
154                 break
155         while True:
156             j1, j2 = torch.randint(x.size(1), (2,))
157             if j1 < j2 - 1:
158                 break
159         return i1, j1, i2, j2
160
161     def task_red_to_green(self, A, f_A, B, f_B):
162         i1, j1, i2, j2 = self.rec_coo(A)
163         A[i1:i2, j1:j2] = self.name2color["red"]
164         f_A[i1:i2, j1:j2] = self.name2color["green"]
165         i1, j1, i2, j2 = self.rec_coo(B)
166         B[i1:i2, j1:j2] = self.name2color["red"]
167         f_B[i1:i2, j1:j2] = self.name2color["green"]
168
169     def generate_prompts_and_answers(self, nb):
170         prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
171         answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
172         w = self.width
173         for prompt, answer in zip(prompts, answers):
174             self.task_red_to_green(
175                 prompt[:, 0 * w : 1 * w],
176                 prompt[:, 1 * w : 2 * w],
177                 prompt[:, 2 * w : 3 * w],
178                 answer,
179             )
180         return prompts, answers
181
182     def save_quizzes(
183         self,
184         result_dir,
185         filename_prefix,
186         prompts,
187         answers,
188         predicted_prompts=None,
189         predicted_answers=None,
190     ):
191         self.save_image(
192             result_dir,
193             filename_prefix + ".png",
194             prompts,
195             answers,
196             predicted_prompts,
197             predicted_answers,
198         )
199
200
201 ######################################################################
202
203 if __name__ == "__main__":
204     import time
205
206     lang = Lang(nb_iterations=4)
207
208     prompts, answers = lang.generate_prompts_and_answers(24)
209
210     # predicted_prompts = torch.rand(prompts.size(0)) < 0.5
211     # predicted_answers = torch.rand(answers.size(0)) < 0.5
212
213     lang.save_quizzes(
214         "/tmp", "test", prompts, answers  # , predicted_prompts, predicted_answers
215     )
216
217     # start_time = time.perf_counter()
218     # token_sequences = lang.generate_token_sequences(nb=64)
219     # delay = time.perf_counter() - start_time
220     # print(f"{token_sequences.size(0)/delay:02f} seq/s")
221
222     # print(lang.seq2str(seq[:4]))
223
224     # for t in range(len(it[0])):
225     # img = torch.cat([lang.frame2img(f[t]) for f in it], dim=0)
226     # torchvision.utils.save_image(
227     # img.float() / 255.0,
228     # f"/tmp/frame_{t:03d}.png",
229     # nrow=8,
230     # padding=6,
231     # pad_value=0,
232     # )
233
234     # m = (torch.rand(seq.size()) < 0.05).long()
235     # seq = (1 - m) * seq + m * 23
236
237     # print(seq.size())
238     # img = lang.seq2img(token_sequences)
239     # print(img.size())
240
241     # torchvision.utils.save_image(
242     # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0
243     # )