fdc1689737831223d9d81925dc6a1a8b39c99d3d
[culture.git] / sky.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, tqdm, os
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
17 import problem
18
19
class Sky(problem.Problem):
    """Toy "birds in the sky" world, rendered as token sequences and images.

    A world is a height x width grid. Every bird is three cells of one color
    forming an L: a head at (i, j) plus tail cells at (i - vi, j) and
    (i, j - vj), where (vi, vj) is a diagonal velocity in {-1, 1}^2. Birds move
    one cell diagonally per frame and bounce off the grid borders.

    A token sequence is nb_iterations flattened frames joined by a direction
    token: token_forward between frames kept in simulation order, or
    token_backward between frames listed in reverse order.
    """

    # RGB palette indexed by token value: row 0 is the background (white),
    # the following rows are the bird colors.
    colors = torch.tensor(
        [
            [255, 255, 255],
            [255, 0, 0],
            [0, 192, 0],
            [0, 0, 255],
            [255, 192, 0],
            [0, 255, 255],
            [255, 0, 255],
            [192, 255, 192],
            [255, 192, 192],
            [192, 192, 255],
            [192, 192, 192],
        ]
    )

    token_background = 0
    first_bird_token = 1
    nb_bird_tokens = colors.size(0) - 1
    # the two direction tokens come right after the bird tokens
    token_forward = first_bird_token + nb_bird_tokens
    token_backward = token_forward + 1

    # one display character per token: "_" background, "A".. birds, ">" / "<" directions
    token2char = (
        "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
    )

    def __init__(self, height=6, width=8, nb_birds=3, speed=1, nb_iterations=4):
        self.height = height
        self.width = width
        self.nb_birds = nb_birds
        self.speed = speed  # stored, but not read anywhere in this class
        self.nb_iterations = nb_iterations

    def direction_tokens(self):
        """Return the (forward, backward) time-direction tokens."""
        return self.token_forward, self.token_backward

    def generate_seq(self, nb, return_frame_sequences=False):
        """Generate nb random worlds.

        Args:
            nb: number of worlds to simulate.
            return_frame_sequences: if True, return the raw simulations as a
                list of int64 tensors of shape (nb_iterations, height, width).

        Returns:
            Otherwise an int64 tensor of shape (nb, L) with
            L = nb_iterations * height * width + nb_iterations - 1. Each row is
            the flattened frames joined by direction tokens, in forward or
            reverse time order with probability 1/2 each.

        Raises:
            ValueError: if height or width is below 2 (no bird fits, so bird
                placement could never succeed) or if nb_birds exceeds the
                number of distinct bird colors.
        """
        self._check_world_parameters()

        frame_sequences = [
            self._simulate_world()
            for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation")
        ]

        if return_frame_sequences:
            return frame_sequences

        if not frame_sequences:
            # torch.cat / torch.stack reject an empty list: return an empty
            # batch with the usual sequence length instead of crashing
            seq_len = self.nb_iterations * self.height * self.width
            seq_len += max(self.nb_iterations - 1, 0)
            return torch.empty(0, seq_len, dtype=torch.int64)

        # Randomize the time direction, and convert to token sequences with
        # the time direction tokens added
        return torch.stack([self._frames_to_tokens(f) for f in frame_sequences])

    def _check_world_parameters(self):
        """Reject grid sizes and bird counts that generate_seq cannot handle."""
        if self.height < 2 or self.width < 2:
            raise ValueError(
                f"height and width must be at least 2, got {self.height}x{self.width}"
            )
        if self.nb_birds > self.nb_bird_tokens:
            raise ValueError(
                f"nb_birds must be at most {self.nb_bird_tokens}, got {self.nb_birds}"
            )

    def _random_bird(self):
        """Draw a head position and a diagonal velocity whose L-shaped body fits.

        Returns (i, j, vi, vj) as Python ints, with vi and vj in {-1, 1}.
        """
        while True:
            i = int(torch.randint(self.height, (1,)))
            j = int(torch.randint(self.width, (1,)))
            vm = int(torch.randint(4, (1,)))
            # bit 0 of vm picks the sign of vi, bit 1 the sign of vj
            vi, vj = (vm % 2) * 2 - 1, (vm // 2) * 2 - 1
            # the tail cells (i - vi, j) and (i, j - vj) must lie in the grid
            if 0 <= i - vi < self.height and 0 <= j - vj < self.width:
                return i, j, vi, vj

    def _simulate_world(self):
        """Simulate one world; return an int64 tensor (nb_iterations, height, width)."""
        frames = torch.zeros(
            self.nb_iterations, self.height, self.width, dtype=torch.int64
        )

        # distinct bird colors, drawn among the bird tokens and sorted
        bird_colors = (
            torch.randperm(self.nb_bird_tokens)[: self.nb_birds].sort().values
            + self.first_bird_token
        ).tolist()

        # all birds are placed before the simulation starts
        birds = [list(self._random_bird()) for _ in range(self.nb_birds)]

        for t in range(self.nb_iterations):
            for c, bird in zip(bird_colors, birds):
                i, j, vi, vj = bird
                # draw the head and the two tail cells; later birds overwrite
                # earlier ones where they overlap
                frames[t, i, j] = c
                frames[t, i - vi, j] = c
                frames[t, i, j - vj] = c
                # bounce off the borders, then move one cell diagonally
                if (i == 0 and vi == -1) or (i == self.height - 1 and vi == 1):
                    vi = -vi
                if (j == 0 and vj == -1) or (j == self.width - 1 and vj == 1):
                    vj = -vj
                bird[:] = [i + vi, j + vj, vi, vj]

        return frames

    def _frames_to_tokens(self, frames):
        """Flatten one frame sequence into a 1-D token tensor, in a random time direction.

        With probability 1/2 the frames stay in simulation order joined by
        token_forward; otherwise they are reversed and joined by token_backward.
        """
        if torch.rand(1) < 0.5:
            ordered, direction = frames, self.token_forward
        else:
            ordered, direction = frames.flip(0), self.token_backward

        pieces = []
        for k, frame in enumerate(ordered):
            if k > 0:
                pieces.append(torch.tensor([direction]))
            pieces.append(frame.flatten())
        return torch.cat(pieces, dim=0)

    ######################################################################

    def frame2img(self, x, scale=15):
        """Render frames as RGB images, one scale x scale cell per grid position.

        x is reshaped to (N, height, width). Returns a tensor of shape
        (N, 3, height * scale - 1, width * scale - 1) holding 0..255 values
        from `colors`, with black grid lines between cells. Tokens outside
        [0, first_bird_token + nb_bird_tokens) -- e.g. direction tokens -- are
        drawn as background and crossed out with an X.
        """
        x = x.reshape(-1, self.height, self.width)
        valid = torch.logical_and(
            x >= 0, x < self.first_bird_token + self.nb_bird_tokens
        )
        # invalid tokens are mapped to index 0 so the palette lookup is safe
        img = self.colors[x * valid.long()].permute(0, 3, 1, 2)
        nb_frames, nb_channels, h, w = img.shape

        # upsample every cell to scale x scale pixels
        img = img[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
        img = img.reshape(nb_frames, nb_channels, h * scale, w * scale)

        # grid lines on the first row/column of every cell, then drop the
        # outermost top and left lines
        img[:, :, :, ::scale] = 0
        img[:, :, ::scale, :] = 0
        img = img[:, :, 1:, 1:]

        # two-pixel-thick X over each invalid cell
        for n, i, j in torch.logical_not(valid).nonzero().tolist():
            for k in range(2, scale - 2):
                for d in (0, 1):
                    img[n, :, i * scale + k, j * scale + k - d] = 0
                    img[n, :, i * scale + scale - 1 - k, j * scale + k - d] = 0

        return img

    def _direction_glyph(self, token, scale):
        """Draw the glyph shown between two frames for one direction token.

        Returns a (3, height * scale - 1, scale) image on the background
        color: a ">" for token_forward, a "<" for token_backward and an "X"
        for any other (invalid) token.
        """
        glyph = self.colors[self.token_background][:, None, None].repeat(
            1, self.height * scale - 1, scale
        )
        half = scale // 2
        # first row of the glyph, placed around the image's middle row
        top = (self.height * scale) // 2 - half

        if token == self.token_forward:
            for k in range(scale):
                for d in (0, 1):
                    glyph[:, top + k - d, 3 + half - abs(k - half)] = 0
        elif token == self.token_backward:
            for k in range(scale):
                for d in (0, 1):
                    glyph[:, top + k - d, 3 + abs(k - half)] = 0
        else:
            for k in range(2, scale - 2):
                for d in (0, 1):
                    glyph[:, top + k - d, k] = 0
                    glyph[:, top + k - d, scale - 1 - k] = 0

        return glyph

    def seq2img(self, seq, scale=15):
        """Render token sequences as images, one horizontal strip per sequence.

        Assumes seq is a 2-D integer tensor laid out like the output of
        generate_seq: a first frame, then (direction token, frame) repeatedly.

        Returns a tensor (nb_sequences, 3, height * scale - 1, total_width):
        the frames rendered by frame2img, each later frame preceded by a
        one-pixel black separator, a direction glyph and another separator.
        """
        frame_len = self.height * self.width
        nb_seq = seq.size(0)

        panels = [
            self.frame2img(
                seq[:, :frame_len].reshape(-1, self.height, self.width), scale
            )
        ]
        separator = torch.full((nb_seq, 3, self.height * scale - 1, 1), 0)
        glyphs = {}  # token value -> glyph image, drawn once per distinct token

        t = frame_len
        while t < seq.size(1):
            direction_images = torch.empty(
                nb_seq, 3, self.height * scale - 1, scale, dtype=self.colors.dtype
            )
            for n, token in enumerate(seq[:, t].tolist()):
                if token not in glyphs:
                    glyphs[token] = self._direction_glyph(token, scale)
                direction_images[n] = glyphs[token]
            t += 1

            frame = seq[:, t : t + frame_len].reshape(-1, self.height, self.width)
            panels += [
                separator,
                direction_images,
                separator,
                self.frame2img(frame, scale),
            ]
            t += frame_len

        return torch.cat(panels, dim=3)

    def seq2str(self, seq):
        """Return one string per sequence, mapping each token through token2char."""
        return ["".join([self.token2char[v] for v in s]) for s in seq]

    def save_image(self, input, result_dir, filename):
        """Render sequences with seq2img and save them as result_dir/filename."""
        img = self.seq2img(input.to("cpu"))
        image_name = os.path.join(result_dir, filename)
        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)

    def save_quizzes(self, input, result_dir, filename_prefix):
        """Save the rendered sequences as result_dir/<filename_prefix>.png."""
        self.save_image(input, result_dir, filename_prefix + ".png")
253
254
255 ######################################################################
256
if __name__ == "__main__":
    import time

    sky = Sky(height=6, width=8, speed=1, nb_iterations=4)

    # time the generation of a small batch of token sequences
    start_time = time.perf_counter()
    seq = sky.generate_seq(nb=64)
    delay = time.perf_counter() - start_time
    # throughput with two decimal places
    print(f"{seq.size(0)/delay:.2f} seq/s")

    print(seq.size())
    img = sky.seq2img(seq)
    print(img.size())

    torchvision.utils.save_image(
        img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0
    )