#!/usr/bin/env python

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import math, sys, tqdm, os

import torch, torchvision

from torch import nn
from torch.nn import functional as F

######################################################################

import problem

class Sky(problem.Problem):
    colors = torch.tensor(
        [
            [255, 255, 255],
            [255, 0, 0],
            [0, 192, 0],
            [0, 0, 255],
            [255, 192, 0],
            [0, 255, 255],
            [255, 0, 255],
            [192, 255, 192],
            [255, 192, 192],
            [192, 192, 255],
            [192, 192, 192],
        ]
    )

    token_background = 0
    first_bird_token = 1
    nb_bird_tokens = colors.size(0) - 1
    token_forward = first_bird_token + nb_bird_tokens
    token_backward = token_forward + 1

    token2char = (
        "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
    )

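    # With the default palette this yields 13 tokens: 0 is the background,
    # 1..10 are the bird colors (printed as "A".."J"), 11 is the forward
    # direction marker (">") and 12 the backward one ("<"), so
    # token2char == "_ABCDEFGHIJ><".
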
    def __init__(self, height=6, width=8, nb_birds=3, nb_iterations=2):
        self.height = height
        self.width = width
        self.nb_birds = nb_birds
        self.nb_iterations = nb_iterations

    def direction_tokens(self):
        return self.token_forward, self.token_backward

    def generate_seq(self, nb, return_iterations=False):
        pairs = []
        kept_iterations = []

        for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
            while True:
                iterations = []

                # Empty frame: 0 is the background token
                f_start = torch.zeros(self.height, self.width, dtype=torch.int64)

                i, j, vi, vj = (
                    torch.empty(self.nb_birds, dtype=torch.int64),
                    torch.empty(self.nb_birds, dtype=torch.int64),
                    torch.empty(self.nb_birds, dtype=torch.int64),
                    torch.empty(self.nb_birds, dtype=torch.int64),
                )

                # One distinct color token per bird, in increasing order
                col = (
                    torch.randperm(self.colors.size(0) - 1)[: self.nb_birds]
                    .sort()
                    .values
                    + 1
                )

                for n in range(self.nb_birds):
                    c = col[n]

                    # Sample a position and a +/-1 velocity until the three
                    # cells of the bird fit in the frame and are still free
                    while True:
                        i[n], j[n] = (
                            torch.randint(self.height, (1,))[0],
                            torch.randint(self.width, (1,))[0],
                        )
                        vm = torch.randint(4, (1,))[0]
                        vi[n], vj[n] = (vm % 2) * 2 - 1, (vm // 2) * 2 - 1
                        if (
                            i[n] - vi[n] >= 0
                            and i[n] - vi[n] < self.height
                            and j[n] - vj[n] >= 0
                            and j[n] - vj[n] < self.width
                            and f_start[i[n], j[n]] == 0
                            and f_start[i[n] - vi[n], j[n]] == 0
                            and f_start[i[n], j[n] - vj[n]] == 0
                        ):
                            break

                    # A bird covers its head cell plus the cells one step
                    # behind it vertically and horizontally
                    f_start[i[n], j[n]] = c
                    f_start[i[n] - vi[n], j[n]] = c
                    f_start[i[n], j[n] - vj[n]] = c

                f_end = f_start.clone()

                for l in range(self.nb_iterations):
                    iterations.append(f_end.clone())
                    f_end[...] = 0
                    nb_collisions = 0
                    for n in range(self.nb_birds):
                        c = col[n]

                        pi, pj, pvi, pvj = (
                            i[n].item(),
                            j[n].item(),
                            vi[n].item(),
                            vj[n].item(),
                        )

                        # Bounce off the frame borders, then move one step
                        if (i[n] == 0 and vi[n] == -1) or (
                            i[n] == self.height - 1 and vi[n] == 1
                        ):
                            vi[n] = -vi[n]
                        if (j[n] == 0 and vj[n] == -1) or (
                            j[n] == self.width - 1 and vj[n] == 1
                        ):
                            vj[n] = -vj[n]

                        i[n] += vi[n]
                        j[n] += vj[n]

                        # Count overlaps with birds already drawn in this frame
                        if not (
                            f_end[i[n], j[n]] == 0
                            and f_end[i[n] - vi[n], j[n]] == 0
                            and f_end[i[n], j[n] - vj[n]] == 0
                        ):
                            nb_collisions += 1

                        f_end[i[n], j[n]] = c
                        f_end[i[n] - vi[n], j[n]] = c
                        f_end[i[n], j[n] - vj[n]] = c

                iterations.append(f_end.clone())

                # Regenerate the whole sample if the last iteration collided
                if nb_collisions == 0:
                    break

            kept_iterations.append(iterations)
            pairs.append((f_start, f_end))

        # Encode each pair as frame / direction token / frame, with a random
        # choice between the forward and the backward orientation
        result = []
        for p in pairs:
            if torch.rand(1) < 0.5:
                result.append(
                    torch.cat(
                        [
                            p[0].flatten(),
                            torch.tensor([self.token_forward]),
                            p[1].flatten(),
                        ],
                        dim=0,
                    )[None, :]
                )
            else:
                result.append(
                    torch.cat(
                        [
                            p[1].flatten(),
                            torch.tensor([self.token_backward]),
                            p[0].flatten(),
                        ],
                        dim=0,
                    )[None, :]
                )

        if return_iterations:
            # iterations = torch.cat([ torch.cat([ x[None, None] for x in l], dim = 1) for l in kept_iterations ], dim=0)
            return torch.cat(result, dim=0), kept_iterations
        else:
            return torch.cat(result, dim=0)

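    # Each generated row is [first frame flattened, one direction token,
    # second frame flattened], i.e. 2 * height * width + 1 tokens (97 with
    # the default 6x8 grid). A minimal usage sketch, not part of the file:
    #
    #   sky = Sky()
    #   quizzes = sky.generate_seq(nb=4)   # shape (4, 97)
    #   print(sky.seq2str(quizzes[:1]))
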
    ######################################################################

    def generate_seq_old(
        self,
        nb,
    ):
        pairs = []

        for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
            f_start = torch.zeros(self.height, self.width, dtype=torch.int64)
            f_end = torch.zeros(self.height, self.width, dtype=torch.int64)
            n = torch.arange(f_start.size(0))

            for c in (
                (torch.randperm(self.nb_bird_tokens) + self.first_bird_token)[
                    : self.nb_birds
                ]
                .sort()
                .values
            ):
                i, j = (
                    torch.randint(self.height - 2, (1,))[0] + 1,
                    torch.randint(self.width - 2, (1,))[0] + 1,
                )
                vm = torch.randint(4, (1,))[0]
                vi, vj = (vm // 2) * (2 * (vm % 2) - 1), (1 - vm // 2) * (
                    2 * (vm % 2) - 1
                )

                f_start[i, j] = c
                f_start[i - vi, j - vj] = c
                f_start[i + vj, j - vi] = c
                f_start[i - vj, j + vi] = c

                for l in range(self.nb_iterations):
                    i += vi
                    j += vj
                    if i < 0 or i >= self.height or j < 0 or j >= self.width:
                        i -= vi
                        j -= vj
                        vi, vj = -vi, -vj
                        i += vi
                        j += vj

                f_end[i, j] = c
                f_end[i - vi, j - vj] = c
                f_end[i + vj, j - vi] = c
                f_end[i - vj, j + vi] = c

            pairs.append((f_start, f_end))

        result = []
        for p in pairs:
            if torch.rand(1) < 0.5:
                result.append(
                    torch.cat(
                        [
                            p[0].flatten(),
                            torch.tensor([self.token_forward]),
                            p[1].flatten(),
                        ],
                        dim=0,
                    )[None, :]
                )
            else:
                result.append(
                    torch.cat(
                        [
                            p[1].flatten(),
                            torch.tensor([self.token_backward]),
                            p[0].flatten(),
                        ],
                        dim=0,
                    )[None, :]
                )

        return torch.cat(result, dim=0)

    def frame2img(self, x, upscale=15):
        x = x.reshape(-1, self.height, self.width)
        # Mask tokens outside the background/bird range so the color lookup
        # stays in bounds; the masked cells get a cross drawn below
        m = torch.logical_and(
            x >= 0, x < self.first_bird_token + self.nb_bird_tokens
        ).long()
        x = self.colors[x * m].permute(0, 3, 1, 2)
        s = x.shape
        # Upscale each cell and draw the black grid lines
        x = x[:, :, :, None, :, None].expand(-1, -1, -1, upscale, -1, upscale)
        x = x.reshape(s[0], s[1], s[2] * upscale, s[3] * upscale)

        x[:, :, :, torch.arange(0, x.size(3), upscale)] = 0
        x[:, :, torch.arange(0, x.size(2), upscale), :] = 0
        x = x[:, :, 1:, 1:]

        # Draw a cross over every cell whose token was out of range
        for n in range(m.size(0)):
            for i in range(m.size(1)):
                for j in range(m.size(2)):
                    if m[n, i, j] == 0:
                        for k in range(2, upscale - 2):
                            x[n, :, i * upscale + k, j * upscale + k] = 0
                            x[n, :, i * upscale + upscale - 1 - k, j * upscale + k] = 0

        return x

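    # frame2img returns int64 RGB images of shape
    # (N, 3, height * upscale - 1, width * upscale - 1); divide by 255
    # before handing them to torchvision.utils.save_image, as done below.
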
    def seq2img(self, seq, upscale=15):
        # Split each sequence back into the two frames and the direction
        # token sitting between them
        f_first = seq[:, : self.height * self.width].reshape(
            -1, self.height, self.width
        )
        f_second = seq[:, self.height * self.width + 1 :].reshape(
            -1, self.height, self.width
        )
        direction = seq[:, self.height * self.width]

        direction_symbol = torch.full(
            (direction.size(0), self.height * upscale - 1, upscale), 0
        )
        direction_symbol = self.colors[direction_symbol].permute(0, 3, 1, 2)
        separator = torch.full((direction.size(0), 3, self.height * upscale - 1, 1), 0)

        # Draw a ">" chevron for the forward token, a "<" chevron for the
        # backward token, and a cross for anything else
        for n in range(direction_symbol.size(0)):
            if direction[n] == self.token_forward:
                for k in range(upscale):
                    direction_symbol[
                        n,
                        :,
                        (self.height * upscale) // 2 - upscale // 2 + k,
                        3 + upscale // 2 - abs(k - upscale // 2),
                    ] = 0
            elif direction[n] == self.token_backward:
                for k in range(upscale):
                    direction_symbol[
                        n,
                        :,
                        (self.height * upscale) // 2 - upscale // 2 + k,
                        3 + abs(k - upscale // 2),
                    ] = 0
            else:
                for k in range(2, upscale - 2):
                    direction_symbol[
                        n, :, (self.height * upscale) // 2 - upscale // 2 + k, k
                    ] = 0
                    direction_symbol[
                        n,
                        :,
                        (self.height * upscale) // 2 - upscale // 2 + k,
                        upscale - 1 - k,
                    ] = 0

        return torch.cat(
            [
                self.frame2img(f_first, upscale),
                separator,
                direction_symbol,
                separator,
                self.frame2img(f_second, upscale),
            ],
            dim=3,
        )

    def seq2str(self, seq):
        result = []
        for s in seq:
            result.append("".join([self.token2char[v] for v in s]))
        return result

    def save_image(self, input, result_dir, filename):
        img = self.seq2img(input.to("cpu"))
        image_name = os.path.join(result_dir, filename)
        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)

    def save_quizzes(self, input, result_dir, filename_prefix):
        self.save_image(input, result_dir, filename_prefix + ".png")

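# A minimal end-to-end sketch (the output directory is an arbitrary choice):
#
#   sky = Sky(height=6, width=8, nb_birds=3)
#   quizzes = sky.generate_seq(nb=12)
#   sky.save_quizzes(quizzes, "/tmp", "sky_quizzes")  # writes /tmp/sky_quizzes.png
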
######################################################################

if __name__ == "__main__":
    import time

    sky = Sky(height=6, width=8, nb_iterations=100)

    start_time = time.perf_counter()
    seq, it = sky.generate_seq(nb=64, return_iterations=True)
    delay = time.perf_counter() - start_time
    print(f"{seq.size(0)/delay:.02f} samples/s")

    print(sky.seq2str(seq[:4]))

    # Dump one image per time step of the kept iterations
    for t in range(len(it[0])):
        img = torch.cat([sky.frame2img(f[t]) for f in it], dim=0)
        torchvision.utils.save_image(
            img.float() / 255.0,
            f"/tmp/frame_{t:03d}.png",
            nrow=8,
            padding=6,
            pad_value=0,
        )

    # m = (torch.rand(seq.size()) < 0.05).long()
    # seq = (1 - m) * seq + m * 23

    img = sky.seq2img(seq)
    print(img.size())

    torchvision.utils.save_image(
        img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0
    )