#!/usr/bin/env python

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import math, sys, tqdm, os

import torch, torchvision

from torch import nn
from torch.nn import functional as F

######################################################################

import problem


class Sky(problem.Problem):
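    """Synthetic 'sky' world: a few multi-cell birds fly diagonally over a
    small grid and bounce off its borders.

    Each frame is a height x width grid of integer tokens: token 0 is the
    empty background, tokens 1..nb_bird_tokens are bird colors, and
    token_forward / token_backward are separators indicating whether a token
    sequence lists its frames in chronological or in reversed order.
    """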
    # RGB color of each token: white background first, then ten bird colors
    colors = torch.tensor(
        [
            [255, 255, 255],
            [255, 0, 0],
            [0, 192, 0],
            [0, 0, 255],
            [255, 192, 0],
            [0, 255, 255],
            [255, 0, 255],
            [192, 255, 192],
            [255, 192, 192],
            [192, 192, 255],
            [192, 192, 192],
        ]
    )

    token_background = 0
    first_bird_token = 1
    nb_bird_tokens = colors.size(0) - 1
    token_forward = first_bird_token + nb_bird_tokens
    token_backward = token_forward + 1

    token2char = (
        "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
    )

    def __init__(
        self,
        height=6,
        width=8,
        nb_birds=3,
        speed=2,
        nb_iterations=2,
        avoid_collision=True,
    ):
        self.height = height
        self.width = width
        self.nb_birds = nb_birds
        self.speed = speed
        self.nb_iterations = nb_iterations
        self.avoid_collision = avoid_collision

    def direction_tokens(self):
        return self.token_forward, self.token_backward

    def generate_frame_sequences(self, nb):
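        """Simulate nb independent worlds and return a list of nb int64
        tensors of shape (nb_iterations, height, width), keeping one frame
        every `speed` simulation steps."""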
        frame_sequences = []

        for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
            i, j, vi, vj = (
                torch.empty(self.nb_birds, dtype=torch.int64),
                torch.empty(self.nb_birds, dtype=torch.int64),
                torch.empty(self.nb_birds, dtype=torch.int64),
                torch.empty(self.nb_birds, dtype=torch.int64),
            )

            def collision_okay():
                # True iff no grid cell is covered by more than one bird
                # (each bird occupies its head cell plus two trailing cells)
                if not self.avoid_collision:
                    return True

                count = torch.zeros(self.height, self.width, dtype=torch.int64)

                for n in range(self.nb_birds):
                    count[i[n], j[n]] += 1
                    count[i[n] - vi[n], j[n]] += 1
                    count[i[n], j[n] - vj[n]] += 1

                return count.max() <= 1

            # pick nb_birds distinct bird colors, in increasing token order
            col = (
                torch.randperm(self.colors.size(0) - 1)[: self.nb_birds].sort().values
                + 1
            )

            while True:
                # sample initial positions and diagonal velocities until the
                # trailing cells fit on the grid and no birds collide
                while True:
                    for n in range(self.nb_birds):
                        while True:
                            i[n] = torch.randint(self.height, (1,))
                            j[n] = torch.randint(self.width, (1,))
                            vm = torch.randint(4, (1,))
                            vi[n], vj[n] = (vm % 2) * 2 - 1, (vm // 2) * 2 - 1
                            if (
                                i[n] - vi[n] >= 0
                                and i[n] - vi[n] < self.height
                                and j[n] - vj[n] >= 0
                                and j[n] - vj[n] < self.width
                            ):
                                break

                    if collision_okay():
                        break

                result = torch.zeros(
                    self.nb_iterations * self.speed,
                    self.height,
                    self.width,
                    dtype=torch.int64,
                )

                fine = torch.empty(self.nb_iterations * self.speed)

                t_to_keep = (
                    torch.arange(self.nb_iterations, device=result.device) * self.speed
                )

                # simulate nb_iterations * speed steps, drawing every bird and
                # bouncing it off the borders
                for l in range(self.nb_iterations * self.speed):
                    fine[l] = collision_okay()
                    for n in range(self.nb_birds):
                        c = col[n]
                        result[l, i[n], j[n]] = c
                        result[l, i[n] - vi[n], j[n]] = c
                        result[l, i[n], j[n] - vj[n]] = c

                        if (i[n] == 0 and vi[n] == -1) or (
                            i[n] == self.height - 1 and vi[n] == 1
                        ):
                            vi[n] = -vi[n]

                        if (j[n] == 0 and vj[n] == -1) or (
                            j[n] == self.width - 1 and vj[n] == 1
                        ):
                            vj[n] = -vj[n]

                        i[n] += vi[n]
                        j[n] += vj[n]

                # keep one frame every `speed` steps, and accept the sequence
                # only if the last kept frame is collision-free
                result = result[t_to_keep]
                fine = fine[t_to_keep]

                if fine[-1]:
                    break

            frame_sequences.append(result)

        return frame_sequences

    ######################################################################

    def generate_prompts_and_answers(self, nb):
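        """Return (prompts, answers): for each sequence, the flattened first
        half of its frames and the flattened second half."""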
        # stack the per-world sequences into a (nb, nb_iterations, height, width)
        # tensor and split each sequence in half along the time dimension
        frame_sequences = torch.stack(self.generate_frame_sequences(nb))
        half = frame_sequences.size(1) // 2
        prompts = frame_sequences[:, :half].flatten(1)
        answers = frame_sequences[:, half:].flatten(1)
        return prompts, answers

    def generate_token_sequences(self, nb):
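        """Flatten each frame sequence into a single token sequence, with
        consecutive frames separated by token_forward when the frames are in
        chronological order, or token_backward when they are reversed (the
        order is chosen at random for each sequence)."""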
        frame_sequences = self.generate_frame_sequences(nb)

        result = []

        for frame_sequence in frame_sequences:
            a = []
            if torch.rand(1) < 0.5:
                for frame in frame_sequence:
                    if len(a) > 0:
                        a.append(torch.tensor([self.token_forward]))
                    a.append(frame.flatten())
            else:
                for frame in reversed(frame_sequence):
                    if len(a) > 0:
                        a.append(torch.tensor([self.token_backward]))
                    a.append(frame.flatten())

            result.append(torch.cat(a, dim=0)[None, :])

        return torch.cat(result, dim=0)

    ######################################################################

    def frame2img(self, x, scale=15):
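        """Render a batch of token grids as RGB images, with roughly
        scale x scale pixels per cell, black grid lines, and a cross drawn
        over any cell whose token is not a background or bird token."""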
        x = x.reshape(-1, self.height, self.width)
        m = torch.logical_and(
            x >= 0, x < self.first_bird_token + self.nb_bird_tokens
        ).long()
        x = self.colors[x * m].permute(0, 3, 1, 2)
        s = x.shape
        x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
        x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)

        # black grid lines every `scale` pixels
        x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
        x[:, :, torch.arange(0, x.size(2), scale), :] = 0
        x = x[:, :, 1:, 1:]

        # cross out every cell whose token is not a valid background / bird token
        for n in range(m.size(0)):
            for i in range(m.size(1)):
                for j in range(m.size(2)):
                    if m[n, i, j] == 0:
                        for k in range(2, scale - 2):
                            for l in [0, 1]:
                                x[n, :, i * scale + k, j * scale + k - l] = 0
                                x[
                                    n, :, i * scale + scale - 1 - k, j * scale + k - l
                                ] = 0

        return x

    def seq2img(self, seq, scale=15):
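        """Render full token sequences as wide images: the frames drawn by
        frame2img, joined by a glyph for the direction token found between
        consecutive frames ('>' forward, '<' backward, a cross otherwise)."""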
        frames = [
            self.frame2img(
                seq[:, : self.height * self.width].reshape(-1, self.height, self.width),
                scale,
            )
        ]

        separator = torch.full((seq.size(0), 3, self.height * scale - 1, 1), 0)

        t = self.height * self.width

        while t < seq.size(1):
            direction_tokens = seq[:, t]
            t += 1

            direction_images = self.colors[
                torch.full(
                    (direction_tokens.size(0), self.height * scale - 1, scale), 0
                )
            ].permute(0, 3, 1, 2)

            for n in range(direction_tokens.size(0)):
                if direction_tokens[n] == self.token_forward:
                    for k in range(scale):
                        for l in [0, 1]:
                            direction_images[
                                n,
                                :,
                                (self.height * scale) // 2 - scale // 2 + k - l,
                                3 + scale // 2 - abs(k - scale // 2),
                            ] = 0
                elif direction_tokens[n] == self.token_backward:
                    for k in range(scale):
                        for l in [0, 1]:
                            direction_images[
                                n,
                                :,
                                (self.height * scale) // 2 - scale // 2 + k - l,
                                3 + abs(k - scale // 2),
                            ] = 0
                else:
                    for k in range(2, scale - 2):
                        for l in [0, 1]:
                            direction_images[
                                n,
                                :,
                                (self.height * scale) // 2 - scale // 2 + k - l,
                                k,
                            ] = 0
                            direction_images[
                                n,
                                :,
                                (self.height * scale) // 2 - scale // 2 + k - l,
                                scale - 1 - k,
                            ] = 0

            frames += [
                separator,
                direction_images,
                separator,
                self.frame2img(
                    seq[:, t : t + self.height * self.width].reshape(
                        -1, self.height, self.width
                    ),
                    scale,
                ),
            ]

            t += self.height * self.width

        return torch.cat(frames, dim=3)

    def seq2str(self, seq):
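        """Return one string per sequence, mapping each token to a character:
        '_' background, 'A'..'J' bird colors, '>' forward, '<' backward."""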
        result = []
        for s in seq:
            result.append("".join([self.token2char[v] for v in s]))
        return result

    def save_image(self, input, result_dir, filename):
        img = self.seq2img(input.to("cpu"))
        image_name = os.path.join(result_dir, filename)
        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)

    def save_quizzes(self, input, result_dir, filename_prefix):
        self.save_image(input, result_dir, filename_prefix + ".png")


######################################################################

if __name__ == "__main__":
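    # quick smoke test: generate a batch of token sequences, report the
    # generation throughput, and render them to /tmp/world.png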
    import time

    sky = Sky(height=6, width=8, speed=4, nb_iterations=2)

    start_time = time.perf_counter()
    token_sequences = sky.generate_token_sequences(nb=64)
    delay = time.perf_counter() - start_time
    print(f"{token_sequences.size(0) / delay:.02f} seq/s")

    # print(sky.seq2str(seq[:4]))

    # for t in range(len(it[0])):
    #     img = torch.cat([sky.frame2img(f[t]) for f in it], dim=0)
    #     torchvision.utils.save_image(
    #         img.float() / 255.0,
    #         f"/tmp/frame_{t:03d}.png",
    #         nrow=8,
    #         padding=6,
    #         pad_value=0,
    #     )

    # m = (torch.rand(seq.size()) < 0.05).long()
    # seq = (1 - m) * seq + m * 23

    # print(seq.size())
    img = sky.seq2img(token_sequences)
    # print(img.size())

    torchvision.utils.save_image(
        img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0
    )