cb25ea0ec335fc1823d2f4d7d044ed7908184a0f
[culture.git] / sky.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, tqdm, os
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
17
class Problem:
    """Interface stub: every method here is a no-op that returns None."""

    def generate_seq(self, nb_train_samples):
        """Placeholder; does nothing and returns None."""
        return None

    def save_quizzes(self, input, result_dir, filename_prefix, logger):
        """Placeholder; does nothing and returns None."""
        return None

    def direction_tokens(self):
        """Placeholder; does nothing and returns None."""
        return None
27
28
class Sky:
    """Grid world of colored three-cell "birds" that move diagonally and bounce
    off the borders.

    A sample is a 1D token sequence made of a flattened frame
    (height * width tokens), one direction token, and a second flattened
    frame. With ``token_forward`` the order is (start, end); with
    ``token_backward`` it is (end, start).
    """

    # RGB color of each cell token; row 0 is the background.
    colors = torch.tensor(
        [
            [255, 255, 255],
            [255, 0, 0],
            [0, 192, 0],
            [0, 0, 255],
            [255, 192, 0],
            [0, 255, 255],
            [255, 0, 255],
            [192, 255, 192],
            [255, 192, 192],
            [192, 192, 255],
            [192, 192, 192],
        ]
    )

    token_background = 0
    first_bird_token = 1
    nb_bird_tokens = colors.size(0) - 1
    token_forward = first_bird_token + nb_bird_tokens
    token_backward = token_forward + 1

    # One character per token: "_" for the background, "A", "B", ... for the
    # bird colors, then ">" and "<" for the forward / backward tokens.
    token2char = (
        "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
    )

    def __init__(self, height=6, width=8, nb_birds=3, nb_iterations=2):
        """Store the grid size, the number of birds per world, and the number
        of simulation steps separating the start frame from the end frame."""
        self.height = height
        self.width = width
        self.nb_birds = nb_birds
        self.nb_iterations = nb_iterations

    def direction_tokens(self):
        """Return the pair ``(token_forward, token_backward)``."""
        return self.token_forward, self.token_backward

    ######################################################################

    def _pairs_to_seq(self, pairs):
        """Turn (f_start, f_end) frame pairs into a (len(pairs), 2*H*W+1) tensor.

        For each pair one ``torch.rand(1)`` draw picks the forward layout
        (start, token_forward, end) or the backward one
        (end, token_backward, start). An empty list of pairs yields an
        empty int64 tensor with the right number of columns.
        """
        rows = []
        for f_start, f_end in pairs:
            if torch.rand(1) < 0.5:
                first, token, second = f_start, self.token_forward, f_end
            else:
                first, token, second = f_end, self.token_backward, f_start
            rows.append(
                torch.cat(
                    [first.flatten(), torch.tensor([token]), second.flatten()],
                    dim=0,
                )[None, :]
            )

        if not rows:
            # torch.cat rejects an empty list, so build the empty result directly.
            return torch.empty(
                0, 2 * self.height * self.width + 1, dtype=torch.int64
            )

        return torch.cat(rows, dim=0)

    def _sample_start_frame(self):
        """Draw a start frame with ``nb_birds`` non-overlapping birds.

        Returns ``(frame, birds)`` where ``birds`` is a list of mutable
        ``[i, j, vi, vj, color]`` records holding Python ints. A bird
        occupies the cells (i, j), (i - vi, j) and (i, j - vj), and its
        velocity components are each -1 or +1.
        """
        frame = torch.zeros(self.height, self.width, dtype=torch.int64)

        # Distinct bird colors, sorted, shifted past the background token.
        col = (
            torch.randperm(self.colors.size(0) - 1)[: self.nb_birds].sort().values
            + 1
        )

        birds = []
        for n in range(self.nb_birds):
            c = col[n].item()

            # Rejection sampling: redraw position and velocity until the three
            # cells are inside the grid and still empty.
            while True:
                i = torch.randint(self.height, (1,)).item()
                j = torch.randint(self.width, (1,)).item()
                vm = torch.randint(4, (1,)).item()
                vi, vj = (vm % 2) * 2 - 1, (vm // 2) * 2 - 1
                if (
                    0 <= i - vi < self.height
                    and 0 <= j - vj < self.width
                    and frame[i, j] == 0
                    and frame[i - vi, j] == 0
                    and frame[i, j - vj] == 0
                ):
                    break

            frame[i, j] = c
            frame[i - vi, j] = c
            frame[i, j - vj] = c
            birds.append([i, j, vi, vj, c])

        return frame, birds

    def _simulate(self, f_start, birds):
        """Run ``nb_iterations`` steps from ``f_start``, updating ``birds`` in place.

        Returns ``(iterations, f_end, nb_collisions)``. ``iterations`` holds
        ``nb_iterations + 1`` frames, from the start frame to the end frame.
        ``nb_collisions`` is reset at every step, so it only counts the
        overlaps found while drawing the last frame.
        """
        frame = f_start.clone()
        iterations = []
        # Bound before the loop so that nb_iterations == 0 is well defined:
        # the end frame is then the collision-free start frame.
        nb_collisions = 0

        for _ in range(self.nb_iterations):
            iterations.append(frame.clone())
            frame[...] = 0
            nb_collisions = 0

            for bird in birds:
                i, j, vi, vj, c = bird

                # Flip a velocity component when the bird heads into a border.
                if (i == 0 and vi == -1) or (i == self.height - 1 and vi == 1):
                    vi = -vi
                if (j == 0 and vj == -1) or (j == self.width - 1 and vj == 1):
                    vj = -vj

                i += vi
                j += vj

                if not (
                    frame[i, j] == 0
                    and frame[i - vi, j] == 0
                    and frame[i, j - vj] == 0
                ):
                    nb_collisions += 1

                frame[i, j] = c
                frame[i - vi, j] = c
                frame[i, j - vj] = c

                bird[:] = [i, j, vi, vj, c]

        iterations.append(frame.clone())

        return iterations, frame, nb_collisions

    def generate_seq(self, nb, return_iterations=False):
        """Generate ``nb`` random (start, end) worlds as token sequences.

        Each world is redrawn until its end frame is drawn without any bird
        overlapping another one.

        Returns a (nb, 2 * height * width + 1) int64 tensor. When
        ``return_iterations`` is true, returns ``(sequences, kept_iterations)``
        where ``kept_iterations[k]`` is the list of the
        ``nb_iterations + 1`` frames of world ``k``.
        """
        pairs = []
        kept_iterations = []

        for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
            while True:
                f_start, birds = self._sample_start_frame()
                iterations, f_end, nb_collisions = self._simulate(f_start, birds)
                if nb_collisions == 0:
                    break

            kept_iterations.append(iterations)
            pairs.append((f_start, f_end))

        seq = self._pairs_to_seq(pairs)

        if return_iterations:
            return seq, kept_iterations
        else:
            return seq

    ######################################################################

    def generate_seq_old(
        self,
        nb,
    ):
        """Older generator: four-cell birds moving along one axis.

        Each bird is a center cell plus three neighbors, moves by one cell
        per step along a row or a column, and reverses direction instead of
        leaving the grid. Birds are drawn without any overlap check, so a
        later bird may overwrite cells of an earlier one.

        Returns a (nb, 2 * height * width + 1) int64 tensor.
        """
        pairs = []

        for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
            f_start = torch.zeros(self.height, self.width, dtype=torch.int64)
            f_end = torch.zeros(self.height, self.width, dtype=torch.int64)

            bird_colors = (
                (torch.randperm(self.nb_bird_tokens) + self.first_bird_token)[
                    : self.nb_birds
                ]
                .sort()
                .values
            )

            for c in bird_colors.tolist():
                # The center is drawn away from the borders.
                i = torch.randint(self.height - 2, (1,)).item() + 1
                j = torch.randint(self.width - 2, (1,)).item() + 1
                vm = torch.randint(4, (1,)).item()
                # Exactly one of vi, vj is non-zero, with value -1 or +1.
                sign = 2 * (vm % 2) - 1
                vi, vj = (vm // 2) * sign, (1 - vm // 2) * sign

                f_start[i, j] = c
                f_start[i - vi, j - vj] = c
                f_start[i + vj, j - vi] = c
                f_start[i - vj, j + vi] = c

                for _step in range(self.nb_iterations):
                    i += vi
                    j += vj
                    if i < 0 or i >= self.height or j < 0 or j >= self.width:
                        # Undo the move, reverse, and step the other way.
                        i -= vi
                        j -= vj
                        vi, vj = -vi, -vj
                        i += vi
                        j += vj

                f_end[i, j] = c
                f_end[i - vi, j - vj] = c
                f_end[i + vj, j - vi] = c
                f_end[i - vj, j + vi] = c

            pairs.append((f_start, f_end))

        return self._pairs_to_seq(pairs)

    def frame2img(self, x, upscale=15):
        """Render frames as RGB images with values in [0, 255].

        ``x`` is reshaped to (-1, height, width). Each cell becomes an
        ``upscale`` x ``upscale`` square of its color, separated by black grid
        lines. Cells whose token is not a background or bird token are drawn
        with the background color and a black cross.

        Returns a (N, 3, height * upscale - 1, width * upscale - 1) tensor.
        """
        x = x.reshape(-1, self.height, self.width)
        # 1 where the token indexes a row of self.colors, 0 elsewhere.
        m = torch.logical_and(
            x >= 0, x < self.first_bird_token + self.nb_bird_tokens
        ).long()
        x = self.colors[x * m].permute(0, 3, 1, 2)
        s = x.shape
        x = x[:, :, :, None, :, None].expand(-1, -1, -1, upscale, -1, upscale)
        x = x.reshape(s[0], s[1], s[2] * upscale, s[3] * upscale)

        # Black grid lines on the first row / column of every cell, then drop
        # the outermost top and left lines.
        x[:, :, :, torch.arange(0, x.size(3), upscale)] = 0
        x[:, :, torch.arange(0, x.size(2), upscale), :] = 0
        x = x[:, :, 1:, 1:]

        # Draw the two diagonals of a cross in the invalid cells only.
        for n, i, j in (m == 0).nonzero().tolist():
            for k in range(2, upscale - 2):
                x[n, :, i * upscale + k, j * upscale + k] = 0
                x[n, :, i * upscale + upscale - 1 - k, j * upscale + k] = 0

        return x

    def seq2img(self, seq, upscale=15):
        """Render token sequences as images: first frame, direction symbol,
        second frame, side by side.

        ``seq`` is indexed as (N, 2 * height * width + 1). The direction
        column shows a ">" shape for ``token_forward``, a "<" shape for
        ``token_backward``, and a cross for any other token.
        """
        nb_cells = self.height * self.width
        f_first = seq[:, :nb_cells].reshape(-1, self.height, self.width)
        f_second = seq[:, nb_cells + 1 :].reshape(-1, self.height, self.width)
        direction = seq[:, nb_cells]

        img_height = self.height * upscale - 1

        # White panel (color 0) that receives the direction symbol.
        direction_symbol = torch.zeros(
            direction.size(0), img_height, upscale, dtype=torch.int64
        )
        direction_symbol = self.colors[direction_symbol].permute(0, 3, 1, 2)
        # One-pixel-wide black column on each side of the direction panel.
        separator = torch.zeros(direction.size(0), 3, img_height, 1, dtype=torch.int64)

        # First row of the symbol, which is vertically centered in the panel.
        top = (self.height * upscale) // 2 - upscale // 2
        half = upscale // 2

        for n in range(direction_symbol.size(0)):
            if direction[n] == self.token_forward:
                for k in range(upscale):
                    direction_symbol[n, :, top + k, 3 + half - abs(k - half)] = 0
            elif direction[n] == self.token_backward:
                for k in range(upscale):
                    direction_symbol[n, :, top + k, 3 + abs(k - half)] = 0
            else:
                for k in range(2, upscale - 2):
                    direction_symbol[n, :, top + k, k] = 0
                    direction_symbol[n, :, top + k, upscale - 1 - k] = 0

        return torch.cat(
            [
                self.frame2img(f_first, upscale),
                separator,
                direction_symbol,
                separator,
                self.frame2img(f_second, upscale),
            ],
            dim=3,
        )

    def seq2str(self, seq):
        """Return one string per sequence, mapping tokens through ``token2char``."""
        return ["".join([self.token2char[v] for v in s]) for s in seq]

    def save_image(self, input, result_dir, filename, logger):
        """Render ``input`` with ``seq2img``, write it to
        ``result_dir/filename``, and report the path through ``logger``."""
        img = self.seq2img(input.to("cpu"))
        image_name = os.path.join(result_dir, filename)
        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
        logger(f"wrote {image_name}")

    def save_quizzes(self, input, result_dir, filename_prefix, logger):
        """Save ``input`` as the image ``filename_prefix + ".png"`` in ``result_dir``."""
        self.save_image(input, result_dir, filename_prefix + ".png", logger)
363
364
365 ######################################################################
366
# Demo script: times the generation of 64 worlds, prints a few of them as
# strings, and writes every simulation frame plus the final quizzes as PNG
# files under /tmp.
if __name__ == "__main__":
    import time

    sky = Sky(height=6, width=8, nb_iterations=100)

    start_time = time.perf_counter()
    seq, it = sky.generate_seq(nb=64, return_iterations=True)
    delay = time.perf_counter() - start_time
    # ".2f" gives two decimals; the previous "02f" spec only set a minimum
    # width of 2 and left the default six decimals.
    print(f"{seq.size(0)/delay:.2f} samples/s")

    print(sky.seq2str(seq[:4]))

    # One image per time step, showing that step for all the worlds.
    for t in range(len(it[0])):
        img = torch.cat([sky.frame2img(f[t]) for f in it], dim=0)
        torchvision.utils.save_image(
            img.float() / 255.0,
            f"/tmp/frame_{t:03d}.png",
            nrow=8,
            padding=6,
            pad_value=0,
        )

    # m = (torch.rand(seq.size()) < 0.05).long()
    # seq = (1 - m) * seq + m * 23

    img = sky.seq2img(seq)
    print(img.size())

    torchvision.utils.save_image(
        img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0
    )