3458d853110096b041ed9f64340a9d1a48f7777f
[culture.git] / sky.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, tqdm
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
17
# RGB palette indexed by token value: entry 0 is the background color and
# every following entry is the color of one bird token.
colors = torch.tensor(
    [
        [255, 255, 255],  # background
        [255, 0, 0],
        [0, 192, 0],
        [0, 0, 255],
        [255, 192, 0],
        [0, 255, 255],
        [255, 0, 255],
        [192, 255, 192],
        [255, 192, 192],
        [192, 192, 255],
        [192, 192, 192],
    ]
)

# Token vocabulary: the background, one token per bird color, then the two
# direction markers that separate the two frames of an encoded sequence.
token_background = 0
first_bird_token = token_background + 1
nb_bird_tokens = len(colors) - 1
token_forward = first_bird_token + nb_bird_tokens
token_backward = token_forward + 1

# Printable character for each token: "_" for the background, capital letters
# for the birds, ">" / "<" for the forward / backward markers.
token2char = "".join(
    ["_"] + [chr(ord("A") + k) for k in range(nb_bird_tokens)] + [">", "<"]
)
41
42
class Sky:
    """Toy world of colored "birds" flying on a ``height x width`` token grid.

    Cells hold ``token_background`` or a bird token. In ``generate_seq`` each
    bird covers three cells -- its head ``(i, j)`` plus ``(i - vi, j)`` and
    ``(i, j - vj)`` -- and moves one diagonal step per iteration (both
    velocity components are -1 or +1), bouncing off the grid borders.

    None of the generation / rendering helpers read instance state, so they
    are static methods: they can be called as ``Sky.generate_seq(...)`` or on
    an instance.
    """

    def __init__(self, height, width):
        self.height = height
        self.width = width

    @staticmethod
    def _bird_cells(i, j, vi, vj):
        """Return the three (row, col) cells of a bird with head (i, j) and velocity (vi, vj)."""
        return ((i, j), (i - vi, j), (i, j - vj))

    @staticmethod
    def _place_birds(height, width, nb_birds):
        """Draw ``nb_birds`` non-overlapping birds with distinct colors on an empty grid.

        Returns ``(f_start, col, i, j, vi, vj)``: the int64 start frame, the
        bird tokens in increasing order, and the per-bird head coordinates and
        velocities (each velocity component is -1 or +1).
        """
        f_start = torch.zeros(height, width, dtype=torch.int64)
        i, j, vi, vj = (torch.empty(nb_birds, dtype=torch.int64) for _ in range(4))

        # Distinct bird tokens, in increasing order.
        col = (
            torch.randperm(nb_bird_tokens)[:nb_birds].sort().values + first_bird_token
        )

        for n in range(nb_birds):
            # Re-draw until the whole bird is inside the grid and on free cells.
            while True:
                i[n], j[n] = (
                    torch.randint(height, (1,))[0],
                    torch.randint(width, (1,))[0],
                )
                vm = torch.randint(4, (1,))[0]
                vi[n], vj[n] = (vm % 2) * 2 - 1, (vm // 2) * 2 - 1
                cells = Sky._bird_cells(i[n], j[n], vi[n], vj[n])
                # The head is always inside the grid; only the tail cells can
                # fall outside, so check those bounds before reading f_start.
                if (
                    0 <= i[n] - vi[n] < height
                    and 0 <= j[n] - vj[n] < width
                    and all(f_start[a, b] == 0 for a, b in cells)
                ):
                    break

            for a, b in cells:
                f_start[a, b] = col[n]

        return f_start, col, i, j, vi, vj

    @staticmethod
    def _simulate(f_start, col, i, j, vi, vj, nb_iterations):
        """Move the birds ``nb_iterations`` steps, updating ``i, j, vi, vj`` in place.

        Returns ``(iterations, f_end, nb_collisions)``: the ``nb_iterations + 1``
        frames from start to end, the final frame, and the number of birds that
        landed on an already-drawn bird during the LAST step only (0 when
        ``nb_iterations`` is 0).
        """
        height, width = f_start.shape
        f_end = f_start.clone()
        iterations = []
        nb_collisions = 0

        for _ in range(nb_iterations):
            iterations.append(f_end.clone())
            f_end[...] = 0
            # NOTE: reset every step, so only the final step's collisions are
            # counted; intermediate frames may contain overlapping birds.
            nb_collisions = 0
            for n, c in enumerate(col):
                # Bounce: reverse a velocity component that would leave the grid.
                if (i[n] == 0 and vi[n] == -1) or (i[n] == height - 1 and vi[n] == 1):
                    vi[n] = -vi[n]
                if (j[n] == 0 and vj[n] == -1) or (j[n] == width - 1 and vj[n] == 1):
                    vj[n] = -vj[n]

                i[n] += vi[n]
                j[n] += vj[n]

                cells = Sky._bird_cells(i[n], j[n], vi[n], vj[n])
                if not all(f_end[a, b] == 0 for a, b in cells):
                    nb_collisions += 1
                for a, b in cells:
                    f_end[a, b] = c

        iterations.append(f_end.clone())
        return iterations, f_end, nb_collisions

    @staticmethod
    def _pairs_to_seq(pairs, height, width):
        """Encode ``(start, end)`` frame pairs as rows of token sequences.

        Each row is ``start, token_forward, end`` or, with probability 1/2,
        ``end, token_backward, start`` (frames flattened row-major). Returns an
        int64 tensor of shape ``(len(pairs), 2 * height * width + 1)``; an empty
        ``pairs`` gives an empty tensor instead of failing in ``torch.cat``.
        """
        if not pairs:
            return torch.empty(0, 2 * height * width + 1, dtype=torch.int64)

        rows = []
        for f_start, f_end in pairs:
            if torch.rand(1) < 0.5:
                parts = [
                    f_start.flatten(),
                    torch.tensor([token_forward]),
                    f_end.flatten(),
                ]
            else:
                parts = [
                    f_end.flatten(),
                    torch.tensor([token_backward]),
                    f_start.flatten(),
                ]
            rows.append(torch.cat(parts, dim=0)[None, :])
        return torch.cat(rows, dim=0)

    @staticmethod
    def generate_seq(
        nb, height, width, nb_birds=3, nb_iterations=2, return_iterations=False
    ):
        """Generate ``nb`` random worlds and encode each (start, end) pair as a sequence.

        Args:
            nb: number of sequences to generate.
            height, width: grid size.
            nb_birds: birds per world, each with a distinct color token.
            nb_iterations: simulation steps between the start and end frames.
            return_iterations: if True, also return every simulated frame.

        Returns:
            An int64 tensor of shape ``(nb, 2 * height * width + 1)`` whose rows
            are ``start, token_forward, end`` or ``end, token_backward, start``.
            If ``return_iterations`` is True, a tuple of that tensor and a list
            of ``nb`` lists holding the ``nb_iterations + 1`` frames of each
            world.
        """
        pairs = []
        kept_iterations = []

        for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
            # Rejection sampling: redraw the world until the final frame has no
            # overlapping birds (the start frame is overlap-free by construction).
            while True:
                f_start, col, i, j, vi, vj = Sky._place_birds(height, width, nb_birds)
                iterations, f_end, nb_collisions = Sky._simulate(
                    f_start, col, i, j, vi, vj, nb_iterations
                )
                if nb_collisions == 0:
                    break

            kept_iterations.append(iterations)
            pairs.append((f_start, f_end))

        result = Sky._pairs_to_seq(pairs, height, width)

        if return_iterations:
            return result, kept_iterations
        return result

    ######################################################################

    @staticmethod
    def generate_seq_old(
        nb,
        height,
        width,
        nb_birds=3,
        nb_iterations=2,
    ):
        """Older variant of ``generate_seq`` with four-cell birds moving along one axis.

        Birds start at least one cell away from the border. When a step would
        leave the grid, the bird reverses its direction. Birds are not checked
        for overlap. The returned encoding is the same as ``generate_seq``'s.
        """
        pairs = []

        for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
            f_start = torch.zeros(height, width, dtype=torch.int64)
            f_end = torch.zeros(height, width, dtype=torch.int64)

            for c in (
                (torch.randperm(nb_bird_tokens) + first_bird_token)[:nb_birds]
                .sort()
                .values
            ):
                i, j = (
                    torch.randint(height - 2, (1,))[0] + 1,
                    torch.randint(width - 2, (1,))[0] + 1,
                )
                vm = torch.randint(4, (1,))[0]
                # Exactly one of (vi, vj) is non-zero, and it is -1 or +1.
                vi, vj = (vm // 2) * (2 * (vm % 2) - 1), (1 - vm // 2) * (
                    2 * (vm % 2) - 1
                )

                f_start[i, j] = c
                f_start[i - vi, j - vj] = c
                f_start[i + vj, j - vi] = c
                f_start[i - vj, j + vi] = c

                for _ in range(nb_iterations):
                    i += vi
                    j += vj
                    if i < 0 or i >= height or j < 0 or j >= width:
                        # Undo the step, reverse, and step the other way.
                        i -= vi
                        j -= vj
                        vi, vj = -vi, -vj
                        i += vi
                        j += vj

                f_end[i, j] = c
                f_end[i - vi, j - vj] = c
                f_end[i + vj, j - vi] = c
                f_end[i - vj, j + vi] = c

            pairs.append((f_start, f_end))

        return Sky._pairs_to_seq(pairs, height, width)

    @staticmethod
    def frame2img(x, height, width, upscale=15):
        """Render token frames as RGB images.

        ``x`` is reshaped to ``(-1, height, width)``. Each cell becomes an
        ``upscale x upscale`` block of its color, with black grid lines between
        cells. A cell holding an out-of-range token is drawn in the background
        color and crossed out. Returns an int64 tensor of shape
        ``(N, 3, height * upscale - 1, width * upscale - 1)`` with values in
        0..255.
        """
        x = x.reshape(-1, height, width)
        m = torch.logical_and(x >= 0, x < first_bird_token + nb_bird_tokens).long()
        # Out-of-range tokens are mapped to token 0 before the color lookup.
        x = colors[x * m].permute(0, 3, 1, 2)
        s = x.shape
        x = x[:, :, :, None, :, None].expand(-1, -1, -1, upscale, -1, upscale)
        x = x.reshape(s[0], s[1], s[2] * upscale, s[3] * upscale)

        # Black grid lines, then drop the leading line on each axis.
        x[:, :, :, torch.arange(0, x.size(3), upscale)] = 0
        x[:, :, torch.arange(0, x.size(2), upscale), :] = 0
        x = x[:, :, 1:, 1:]

        # Draw a black X over every cell whose token was out of range.
        for n, i, j in (m == 0).nonzero().tolist():
            for k in range(2, upscale - 2):
                x[n, :, i * upscale + k, j * upscale + k] = 0
                x[n, :, i * upscale + upscale - 1 - k, j * upscale + k] = 0

        return x

    @staticmethod
    def seq2img(seq, height, width, upscale=15):
        """Render encoded sequences as ``[frame | direction glyph | frame]`` images.

        Each row of ``seq`` is ``first frame, direction token, second frame``, as
        built by ``generate_seq``. The glyph is a ">" for ``token_forward``, a "<"
        for ``token_backward`` and an X for any other token. It is framed by
        black one-pixel separators.
        """
        f_first = seq[:, : height * width].reshape(-1, height, width)
        f_second = seq[:, height * width + 1 :].reshape(-1, height, width)
        direction = seq[:, height * width]

        direction_symbol = torch.full(
            (direction.size(0), height * upscale - 1, upscale), 0
        )
        direction_symbol = colors[direction_symbol].permute(0, 3, 1, 2)
        separator = torch.full((direction.size(0), 3, height * upscale - 1, 1), 0)

        # First row of the glyph, chosen so the glyph is vertically centered.
        top = (height * upscale) // 2 - upscale // 2
        half = upscale // 2

        for n in range(direction_symbol.size(0)):
            if direction[n] == token_forward:
                for k in range(upscale):
                    direction_symbol[n, :, top + k, 3 + half - abs(k - half)] = 0
            elif direction[n] == token_backward:
                for k in range(upscale):
                    direction_symbol[n, :, top + k, 3 + abs(k - half)] = 0
            else:
                for k in range(2, upscale - 2):
                    direction_symbol[n, :, top + k, k] = 0
                    direction_symbol[n, :, top + k, upscale - 1 - k] = 0

        return torch.cat(
            [
                Sky.frame2img(f_first, height, width, upscale),
                separator,
                direction_symbol,
                separator,
                Sky.frame2img(f_second, height, width, upscale),
            ],
            dim=3,
        )

    @staticmethod
    def seq2str(seq):
        """Convert each token sequence in ``seq`` to a string via ``token2char``."""
        return ["".join(token2char[v] for v in s) for s in seq]
319             result.append("".join([token2char[v] for v in s]))
320         return result
321
322
323 ######################################################################
324
325 if __name__ == "__main__":
326     import time
327
328     height, width = 6, 8
329     start_time = time.perf_counter()
330     seq, it = generate_seq(
331         nb=64, height=height, width=width, nb_iterations=100, return_iterations=True
332     )
333     delay = time.perf_counter() - start_time
334     print(f"{seq.size(0)/delay:02f} samples/s")
335
336     print(seq2str(seq[:4]))
337
338     for t in range(len(it[0])):
339         img = torch.cat([frame2img(f[t], height, width) for f in it], dim=0)
340         torchvision.utils.save_image(
341             img.float() / 255.0,
342             f"/tmp/frame_{t:03d}.png",
343             nrow=8,
344             padding=6,
345             pad_value=0,
346         )
347
348     # m = (torch.rand(seq.size()) < 0.05).long()
349     # seq = (1 - m) * seq + m * 23
350
351     img = seq2img(seq, height, width)
352     print(img.size())
353
354     torchvision.utils.save_image(
355         img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0
356     )