a90e37d542aebd5260b3dbdb7118c0ba4e2cc33f
[culture.git] / sky.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, tqdm
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
17
class Sky:
    """Generator and renderer for a toy "flying birds" grid world.

    A world is a ``height x width`` grid of int64 tokens. Token
    ``token_background`` (0) is empty space and tokens ``first_bird_token ..
    first_bird_token + nb_bird_tokens - 1`` are bird colors (row indices
    into ``colors``). Generated sequences have the layout
    ``[frame_a, direction, frame_b]`` where each frame is flattened
    (``height * width`` tokens) and ``direction`` is ``token_forward`` when
    ``frame_b`` comes after ``frame_a`` in time, ``token_backward`` otherwise.
    """

    # RGB color for each frame token; row 0 is the background color.
    colors = torch.tensor(
        [
            [255, 255, 255],
            [255, 0, 0],
            [0, 192, 0],
            [0, 0, 255],
            [255, 192, 0],
            [0, 255, 255],
            [255, 0, 255],
            [192, 255, 192],
            [255, 192, 192],
            [192, 192, 255],
            [192, 192, 192],
        ]
    )

    token_background = 0
    first_bird_token = 1
    nb_bird_tokens = colors.size(0) - 1
    token_forward = first_bird_token + nb_bird_tokens
    token_backward = token_forward + 1

    # One printable character per token: "_" for background, "A", "B", ...
    # for the bird colors, then ">" / "<" for the forward / backward tokens.
    token2char = (
        "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
    )

    def __init__(self, height=6, width=8, nb_birds=3, nb_iterations=2):
        """Configure the world.

        Args:
            height: number of grid rows.
            width: number of grid columns.
            nb_birds: number of birds per world; each gets a distinct color,
                so this should not exceed ``nb_bird_tokens``.
            nb_iterations: number of simulation steps between the start and
                end frames.
        """
        self.height = height
        self.width = width
        self.nb_birds = nb_birds
        self.nb_iterations = nb_iterations

    def _pairs_to_seq(self, pairs):
        """Turn (start, end) frame pairs into randomly oriented token sequences.

        For each pair, with probability 1/2 the sequence is
        ``[start, token_forward, end]``, otherwise ``[end, token_backward,
        start]``, with both frames flattened.

        Args:
            pairs: list of ``(f_start, f_end)`` int64 frames of shape
                ``(height, width)``.

        Returns:
            An int64 tensor of shape ``(len(pairs), 2 * height * width + 1)``.
        """
        sequences = []
        for f_start, f_end in pairs:
            if torch.rand(1) < 0.5:
                first, direction, second = f_start, self.token_forward, f_end
            else:
                first, direction, second = f_end, self.token_backward, f_start
            sequences.append(
                torch.cat(
                    [first.flatten(), torch.tensor([direction]), second.flatten()],
                    dim=0,
                )[None, :]
            )

        if not sequences:
            # torch.cat rejects an empty list; return an empty batch instead.
            return torch.empty(
                0, 2 * self.height * self.width + 1, dtype=torch.int64
            )

        return torch.cat(sequences, dim=0)

    def generate_seq(self, nb, return_iterations=False):
        """Generate ``nb`` random world sequences.

        Each world holds ``self.nb_birds`` birds of distinct colors. A bird
        occupies three cells of its color: the head ``(i, j)`` plus
        ``(i - vi, j)`` and ``(i, j - vj)``, where the velocity ``(vi, vj)``
        has both components in {-1, 1}. Every step, a velocity component
        pointing out of the grid from a border head is reversed, then the
        head moves by ``(vi, vj)``.

        The start frame is built without overlaps. A world is re-drawn from
        scratch if birds overlap in the final frame; overlaps in
        intermediate frames are NOT checked, so frames returned with
        ``return_iterations=True`` may show birds drawn over each other.

        Args:
            nb: number of sequences to generate.
            return_iterations: if True, also return, for every world, the
                list of its ``nb_iterations + 1`` frames (start to end).

        Returns:
            An int64 tensor of shape ``(nb, 2 * height * width + 1)``, and,
            if ``return_iterations`` is True, the list of per-world frames.
        """
        pairs = []
        kept_iterations = []

        for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
            while True:
                iterations = []

                f_start = torch.zeros(self.height, self.width, dtype=torch.int64)

                # Per-bird head position (i, j) and velocity (vi, vj).
                i, j, vi, vj = (
                    torch.empty(self.nb_birds, dtype=torch.int64),
                    torch.empty(self.nb_birds, dtype=torch.int64),
                    torch.empty(self.nb_birds, dtype=torch.int64),
                    torch.empty(self.nb_birds, dtype=torch.int64),
                )

                # Distinct bird color tokens, in increasing order.
                col = (
                    torch.randperm(self.nb_bird_tokens)[: self.nb_birds]
                    .sort()
                    .values
                    + self.first_bird_token
                )

                for n in range(self.nb_birds):
                    c = col[n]

                    # Rejection-sample a placement whose three cells are
                    # inside the grid and not already occupied.
                    while True:
                        i[n], j[n] = (
                            torch.randint(self.height, (1,))[0],
                            torch.randint(self.width, (1,))[0],
                        )
                        vm = torch.randint(4, (1,))[0]
                        vi[n], vj[n] = (vm % 2) * 2 - 1, (vm // 2) * 2 - 1
                        if (
                            i[n] - vi[n] >= 0
                            and i[n] - vi[n] < self.height
                            and j[n] - vj[n] >= 0
                            and j[n] - vj[n] < self.width
                            and f_start[i[n], j[n]] == 0
                            and f_start[i[n] - vi[n], j[n]] == 0
                            and f_start[i[n], j[n] - vj[n]] == 0
                        ):
                            break

                    f_start[i[n], j[n]] = c
                    f_start[i[n] - vi[n], j[n]] = c
                    f_start[i[n], j[n] - vj[n]] = c

                f_end = f_start.clone()

                # Initialised here so that nb_iterations == 0 (f_end is the
                # collision-free f_start) does not hit an unbound name below.
                nb_collisions = 0

                for _ in range(self.nb_iterations):
                    iterations.append(f_end.clone())
                    f_end[...] = 0
                    # Reset every step: only the last step's count decides
                    # whether the world is kept.
                    nb_collisions = 0
                    for n in range(self.nb_birds):
                        c = col[n]

                        # Bounce off the borders before moving.
                        if (i[n] == 0 and vi[n] == -1) or (
                            i[n] == self.height - 1 and vi[n] == 1
                        ):
                            vi[n] = -vi[n]
                        if (j[n] == 0 and vj[n] == -1) or (
                            j[n] == self.width - 1 and vj[n] == 1
                        ):
                            vj[n] = -vj[n]

                        i[n] += vi[n]
                        j[n] += vj[n]

                        # A collision is any cell already drawn by an
                        # earlier bird in this frame.
                        if not (
                            f_end[i[n], j[n]] == 0
                            and f_end[i[n] - vi[n], j[n]] == 0
                            and f_end[i[n], j[n] - vj[n]] == 0
                        ):
                            nb_collisions += 1

                        f_end[i[n], j[n]] = c
                        f_end[i[n] - vi[n], j[n]] = c
                        f_end[i[n], j[n] - vj[n]] = c

                iterations.append(f_end.clone())

                if nb_collisions == 0:
                    break

            kept_iterations.append(iterations)
            pairs.append((f_start, f_end))

        result = self._pairs_to_seq(pairs)

        if return_iterations:
            return result, kept_iterations
        else:
            return result

    ######################################################################

    def generate_seq_old(
        self,
        nb,
    ):
        """Legacy generator with four-cell birds moving along the grid axes.

        Each bird has a head ``(i, j)``, a tail ``(i - vi, j - vj)`` and two
        side cells ``(i + vj, j - vi)`` / ``(i - vj, j + vi)``; its velocity
        is one of (0, +-1) or (+-1, 0). A step that would move the head off
        the grid reverses the velocity instead. Overlaps are not checked: a
        later bird simply overwrites earlier ones.

        Args:
            nb: number of sequences to generate.

        Returns:
            An int64 tensor of shape ``(nb, 2 * height * width + 1)``.
        """
        pairs = []

        for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
            f_start = torch.zeros(self.height, self.width, dtype=torch.int64)
            f_end = torch.zeros(self.height, self.width, dtype=torch.int64)

            for c in (
                (torch.randperm(self.nb_bird_tokens) + self.first_bird_token)[
                    : self.nb_birds
                ]
                .sort()
                .values
            ):
                # Heads start one cell away from every border so the side
                # cells stay inside the grid.
                i, j = (
                    torch.randint(self.height - 2, (1,))[0] + 1,
                    torch.randint(self.width - 2, (1,))[0] + 1,
                )
                vm = torch.randint(4, (1,))[0]
                vi, vj = (vm // 2) * (2 * (vm % 2) - 1), (1 - vm // 2) * (
                    2 * (vm % 2) - 1
                )

                f_start[i, j] = c
                f_start[i - vi, j - vj] = c
                f_start[i + vj, j - vi] = c
                f_start[i - vj, j + vi] = c

                for _ in range(self.nb_iterations):
                    i += vi
                    j += vj
                    if i < 0 or i >= self.height or j < 0 or j >= self.width:
                        # Undo the step, turn around, and step again.
                        i -= vi
                        j -= vj
                        vi, vj = -vi, -vj
                        i += vi
                        j += vj

                f_end[i, j] = c
                f_end[i - vi, j - vj] = c
                f_end[i + vj, j - vi] = c
                f_end[i - vj, j + vi] = c

            pairs.append((f_start, f_end))

        return self._pairs_to_seq(pairs)

    def frame2img(self, x, upscale=15):
        """Render frames as RGB images.

        Args:
            x: token tensor that can be reshaped to ``(-1, height, width)``.
            upscale: size in pixels of one cell, including its one-pixel
                black grid line.

        Returns:
            An int64 tensor of shape
            ``(N, 3, height * upscale - 1, width * upscale - 1)`` with
            values in 0..255. Cells holding a token that is not a color
            index are painted with the background color and a black cross.
        """
        x = x.reshape(-1, self.height, self.width)
        # 1 where the token is a valid color index, 0 elsewhere.
        m = torch.logical_and(
            x >= 0, x < self.first_bird_token + self.nb_bird_tokens
        ).long()
        # Invalid tokens are mapped to index 0 (background color).
        x = self.colors[x * m].permute(0, 3, 1, 2)
        s = x.shape
        # Replicate every cell into an upscale x upscale pixel block.
        x = x[:, :, :, None, :, None].expand(-1, -1, -1, upscale, -1, upscale)
        x = x.reshape(s[0], s[1], s[2] * upscale, s[3] * upscale)

        # Black grid lines on the first row/column of each cell; the
        # outermost top row and left column are then cropped away.
        x[:, :, :, torch.arange(0, x.size(3), upscale)] = 0
        x[:, :, torch.arange(0, x.size(2), upscale), :] = 0
        x = x[:, :, 1:, 1:]

        # Draw a black cross over cells that held an invalid token.
        for n in range(m.size(0)):
            for i in range(m.size(1)):
                for j in range(m.size(2)):
                    if m[n, i, j] == 0:
                        for k in range(2, upscale - 2):
                            x[n, :, i * upscale + k, j * upscale + k] = 0
                            x[n, :, i * upscale + upscale - 1 - k, j * upscale + k] = 0

        return x

    def seq2img(self, seq, upscale=15):
        """Render sequences as ``[frame_a | direction | frame_b]`` images.

        Args:
            seq: int tensor of shape ``(N, 2 * height * width + 1)``.
            upscale: cell size in pixels, passed to ``frame2img``.

        Returns:
            An int64 tensor of shape ``(N, 3, height * upscale - 1, W)``
            with ``W = 2 * (width * upscale - 1) + upscale + 2``. The middle
            panel shows a ">" for ``token_forward``, a "<" for
            ``token_backward`` and a cross for any other direction token.
        """
        f_first = seq[:, : self.height * self.width].reshape(
            -1, self.height, self.width
        )
        f_second = seq[:, self.height * self.width + 1 :].reshape(
            -1, self.height, self.width
        )
        direction = seq[:, self.height * self.width]

        # Background-colored panel for the direction symbol.
        direction_symbol = torch.full(
            (direction.size(0), self.height * upscale - 1, upscale), 0
        )
        direction_symbol = self.colors[direction_symbol].permute(0, 3, 1, 2)
        # One-pixel black vertical line between panels.
        separator = torch.full((direction.size(0), 3, self.height * upscale - 1, 1), 0)

        for n in range(direction_symbol.size(0)):
            if direction[n] == self.token_forward:
                for k in range(upscale):
                    direction_symbol[
                        n,
                        :,
                        (self.height * upscale) // 2 - upscale // 2 + k,
                        3 + upscale // 2 - abs(k - upscale // 2),
                    ] = 0
            elif direction[n] == self.token_backward:
                for k in range(upscale):
                    direction_symbol[
                        n,
                        :,
                        (self.height * upscale) // 2 - upscale // 2 + k,
                        3 + abs(k - upscale // 2),
                    ] = 0
            else:
                for k in range(2, upscale - 2):
                    direction_symbol[
                        n, :, (self.height * upscale) // 2 - upscale // 2 + k, k
                    ] = 0
                    direction_symbol[
                        n,
                        :,
                        (self.height * upscale) // 2 - upscale // 2 + k,
                        upscale - 1 - k,
                    ] = 0

        return torch.cat(
            [
                self.frame2img(f_first, upscale),
                separator,
                direction_symbol,
                separator,
                self.frame2img(f_second, upscale),
            ],
            dim=3,
        )

    def seq2str(self, seq):
        """Return one string per sequence, one ``token2char`` character per token."""
        result = []
        for s in seq:
            result.append("".join([self.token2char[v] for v in s]))
        return result
338             result.append("".join([self.token2char[v] for v in s]))
339         return result
340
341
342 ######################################################################
343
if __name__ == "__main__":
    import time

    sky = Sky(height=6, width=8, nb_iterations=100)

    start_time = time.perf_counter()
    seq, iterations = sky.generate_seq(nb=64, return_iterations=True)
    delay = time.perf_counter() - start_time
    # Throughput with two decimals.
    print(f"{seq.size(0) / delay:.2f} samples/s")

    print(sky.seq2str(seq[:4]))

    # One image per time step, tiling that step's frame of every world.
    for t in range(len(iterations[0])):
        img = torch.cat([sky.frame2img(frames[t]) for frames in iterations], dim=0)
        torchvision.utils.save_image(
            img.float() / 255.0,
            f"/tmp/frame_{t:03d}.png",
            nrow=8,
            padding=6,
            pad_value=0,
        )

    img = sky.seq2img(seq)
    print(img.size())

    torchvision.utils.save_image(
        img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0
    )