[culture.git] / sku.py
#!/usr/bin/env python

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import math, sys, tqdm

import torch, torchvision

from torch import nn
from torch.nn import functional as F
######################################################################


colors = torch.tensor(
    [
        [255, 255, 255],
        [255, 0, 0],
        [0, 192, 0],
        [0, 0, 255],
        [255, 192, 0],
        [0, 255, 255],
        [255, 0, 255],
        [192, 255, 192],
        [255, 192, 192],
        [192, 192, 255],
        [192, 192, 192],
    ]
)

token_background = 0
first_bird_token = 1
nb_bird_tokens = colors.size(0) - 1
token_forward = first_bird_token + nb_bird_tokens
token_backward = token_forward + 1

token2char = "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"


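# Rough sketch of the token vocabulary above, as read from the code: token 0 is
# the empty background, tokens 1..nb_bird_tokens are the bird colors, and the
# two extra tokens encode the direction of a frame pair -- token_forward for
# "start frame then end frame", token_backward for "end frame then start
# frame". With the default palette:
#
#   >>> token_forward, token_backward
#   (11, 12)
#
# A generated sequence is one frame flattened (height * width tokens), the
# direction token, then the other frame flattened, i.e. a total of
# 2 * height * width + 1 tokens.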
def generate_seq(
    nb, height, width, nb_birds=3, nb_iterations=2, return_iterations=False
):
    pairs = []
    kept_iterations = []

    for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
        # Re-sample the whole scene until the simulated dynamics are accepted
        while True:
            iterations = []

            f_start = torch.zeros(height, width, dtype=torch.int64)

            i, j, vi, vj = (
                torch.empty(nb_birds, dtype=torch.int64),
                torch.empty(nb_birds, dtype=torch.int64),
                torch.empty(nb_birds, dtype=torch.int64),
                torch.empty(nb_birds, dtype=torch.int64),
            )

            col = torch.randperm(colors.size(0) - 1)[:nb_birds].sort().values + 1

            for n in range(nb_birds):
                c = col[n]

                # Sample a position and a diagonal speed until the three cells
                # of the bird fit in the frame and do not overlap another bird
                while True:
                    i[n], j[n] = (
                        torch.randint(height, (1,))[0],
                        torch.randint(width, (1,))[0],
                    )
                    vm = torch.randint(4, (1,))[0]
                    vi[n], vj[n] = (vm % 2) * 2 - 1, (vm // 2) * 2 - 1
                    if (
                        i[n] - vi[n] >= 0
                        and i[n] - vi[n] < height
                        and j[n] - vj[n] >= 0
                        and j[n] - vj[n] < width
                        and f_start[i[n], j[n]] == 0
                        and f_start[i[n] - vi[n], j[n]] == 0
                        and f_start[i[n], j[n] - vj[n]] == 0
                    ):
                        break

                f_start[i[n], j[n]] = c
                f_start[i[n] - vi[n], j[n]] = c
                f_start[i[n], j[n] - vj[n]] = c

            f_end = f_start.clone()

            for l in range(nb_iterations):
                iterations.append(f_end.clone())
                f_end[...] = 0
                # Reset at every iteration, so only the collisions of the last
                # iteration are checked below
                nb_collisions = 0
                for n in range(nb_birds):
                    c = col[n]

                    pi, pj, pvi, pvj = (
                        i[n].item(),
                        j[n].item(),
                        vi[n].item(),
                        vj[n].item(),
                    )

                    # Bounce on the frame borders
                    if (i[n] == 0 and vi[n] == -1) or (
                        i[n] == height - 1 and vi[n] == 1
                    ):
                        vi[n] = -vi[n]
                    if (j[n] == 0 and vj[n] == -1) or (
                        j[n] == width - 1 and vj[n] == 1
                    ):
                        vj[n] = -vj[n]

                    i[n] += vi[n]
                    j[n] += vj[n]

                    if not (
                        f_end[i[n], j[n]] == 0
                        and f_end[i[n] - vi[n], j[n]] == 0
                        and f_end[i[n], j[n] - vj[n]] == 0
                    ):
                        nb_collisions += 1

                    f_end[i[n], j[n]] = c
                    f_end[i[n] - vi[n], j[n]] = c
                    f_end[i[n], j[n] - vj[n]] = c

            iterations.append(f_end.clone())

            if nb_collisions == 0:
                break

        kept_iterations.append(iterations)
        pairs.append((f_start, f_end))

    result = []
    for p in pairs:
        # Encode the pair either forward (start, token, end) or backward
        # (end, token, start), picked at random
        if torch.rand(1) < 0.5:
            result.append(
                torch.cat(
                    [p[0].flatten(), torch.tensor([token_forward]), p[1].flatten()],
                    dim=0,
                )[None, :]
            )
        else:
            result.append(
                torch.cat(
                    [p[1].flatten(), torch.tensor([token_backward]), p[0].flatten()],
                    dim=0,
                )[None, :]
            )

    if return_iterations:
        # iterations = torch.cat([ torch.cat([ x[None, None] for x in l], dim = 1) for l in kept_iterations ], dim=0)
        return torch.cat(result, dim=0), kept_iterations
    else:
        return torch.cat(result, dim=0)

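# Minimal usage sketch (under the defaults above): each row of the returned
# tensor is one flattened frame, a direction token, and the other flattened
# frame, hence 2 * height * width + 1 tokens.
#
#   seq = generate_seq(nb=4, height=6, width=8)
#   assert seq.shape == (4, 2 * 6 * 8 + 1)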

######################################################################


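# Presumably an earlier version of the generator, kept as-is; it is not called
# elsewhere in this file. It places four-cell birds, moves them without any
# collision check, and keeps only the first and last frames.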
def generate_seq_old(
    nb,
    height,
    width,
    nb_birds=3,
    nb_iterations=2,
):
    pairs = []

    for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
        f_start = torch.zeros(height, width, dtype=torch.int64)
        f_end = torch.zeros(height, width, dtype=torch.int64)
        n = torch.arange(f_start.size(0))

        for c in (
            (torch.randperm(nb_bird_tokens) + first_bird_token)[:nb_birds].sort().values
        ):
            i, j = (
                torch.randint(height - 2, (1,))[0] + 1,
                torch.randint(width - 2, (1,))[0] + 1,
            )
            vm = torch.randint(4, (1,))[0]
            vi, vj = (vm // 2) * (2 * (vm % 2) - 1), (1 - vm // 2) * (2 * (vm % 2) - 1)

            f_start[i, j] = c
            f_start[i - vi, j - vj] = c
            f_start[i + vj, j - vi] = c
            f_start[i - vj, j + vi] = c

            for l in range(nb_iterations):
                i += vi
                j += vj
                if i < 0 or i >= height or j < 0 or j >= width:
                    i -= vi
                    j -= vj
                    vi, vj = -vi, -vj
                    i += vi
                    j += vj

            f_end[i, j] = c
            f_end[i - vi, j - vj] = c
            f_end[i + vj, j - vi] = c
            f_end[i - vj, j + vi] = c

        pairs.append((f_start, f_end))

    result = []
    for p in pairs:
        if torch.rand(1) < 0.5:
            result.append(
                torch.cat(
                    [p[0].flatten(), torch.tensor([token_forward]), p[1].flatten()],
                    dim=0,
                )[None, :]
            )
        else:
            result.append(
                torch.cat(
                    [p[1].flatten(), torch.tensor([token_backward]), p[0].flatten()],
                    dim=0,
                )[None, :]
            )

    return torch.cat(result, dim=0)


def frame2img(x, height, width, upscale=15):
    x = x.reshape(-1, height, width)
    m = torch.logical_and(x >= 0, x < first_bird_token + nb_bird_tokens).long()
    x = colors[x * m].permute(0, 3, 1, 2)
    s = x.shape
    x = x[:, :, :, None, :, None].expand(-1, -1, -1, upscale, -1, upscale)
    x = x.reshape(s[0], s[1], s[2] * upscale, s[3] * upscale)

    # Draw the grid lines and crop the outer border
    x[:, :, :, torch.arange(0, x.size(3), upscale)] = 0
    x[:, :, torch.arange(0, x.size(2), upscale), :] = 0
    x = x[:, :, 1:, 1:]

    # Cross out the cells whose token is neither background nor a bird color
    for n in range(m.size(0)):
        for i in range(m.size(1)):
            for j in range(m.size(2)):
                if m[n, i, j] == 0:
                    for k in range(2, upscale - 2):
                        x[n, :, i * upscale + k, j * upscale + k] = 0
                        x[n, :, i * upscale + upscale - 1 - k, j * upscale + k] = 0

    return x

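# Minimal shape sketch: frame2img maps integer frames to an RGB image batch,
# turning every cell into an (upscale x upscale) block separated by grid lines.
#
#   frames = generate_seq(nb=2, height=6, width=8)[:, : 6 * 8]
#   img = frame2img(frames, 6, 8)
#   # img.shape == (2, 3, 6 * 15 - 1, 8 * 15 - 1), values in 0..255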

def seq2img(seq, height, width, upscale=15):
    f_first = seq[:, : height * width].reshape(-1, height, width)
    f_second = seq[:, height * width + 1 :].reshape(-1, height, width)
    direction = seq[:, height * width]

    direction_symbol = torch.full((direction.size(0), height * upscale - 1, upscale), 0)
    direction_symbol = colors[direction_symbol].permute(0, 3, 1, 2)
    separator = torch.full((direction.size(0), 3, height * upscale - 1, 1), 0)

    for n in range(direction_symbol.size(0)):
        if direction[n] == token_forward:
            # Draw a ">" chevron
            for k in range(upscale):
                direction_symbol[
                    n,
                    :,
                    (height * upscale) // 2 - upscale // 2 + k,
                    3 + upscale // 2 - abs(k - upscale // 2),
                ] = 0
        elif direction[n] == token_backward:
            # Draw a "<" chevron
            for k in range(upscale):
                direction_symbol[
                    n,
                    :,
                    (height * upscale) // 2 - upscale // 2 + k,
                    3 + abs(k - upscale // 2),
                ] = 0
        else:
            # Unknown direction token: draw a cross
            for k in range(2, upscale - 2):
                direction_symbol[
                    n, :, (height * upscale) // 2 - upscale // 2 + k, k
                ] = 0
                direction_symbol[
                    n, :, (height * upscale) // 2 - upscale // 2 + k, upscale - 1 - k
                ] = 0

    return torch.cat(
        [
            frame2img(f_first, height, width, upscale),
            separator,
            direction_symbol,
            separator,
            frame2img(f_second, height, width, upscale),
        ],
        dim=3,
    )

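# Minimal usage sketch: seq2img renders a batch of full sequences, placing the
# two frames side by side with the direction symbol in between, much as the
# __main__ block below does.
#
#   seq = generate_seq(nb=2, height=6, width=8)
#   img = seq2img(seq, 6, 8)
#   torchvision.utils.save_image(img.float() / 255.0, "/tmp/example.png")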

def seq2str(seq):
    result = []
    for s in seq:
        result.append("".join([token2char[v] for v in s]))
    return result

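# Sketch of the string rendering: "_" is the background, capital letters are
# the bird colors, and ">" / "<" is the direction token, so each sequence
# becomes a string of 2 * height * width + 1 such characters.
#
#   print(seq2str(generate_seq(nb=1, height=6, width=8))[0])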

######################################################################

if __name__ == "__main__":
    import time

    height, width = 6, 8
    start_time = time.perf_counter()
    seq, it = generate_seq(
        nb=64, height=height, width=width, nb_iterations=100, return_iterations=True
    )
    delay = time.perf_counter() - start_time
    print(f"{seq.size(0)/delay:02f} samples/s")

    print(seq2str(seq[:4]))

    # Dump every iteration of the kept trajectories as a grid of frames
    for t in range(len(it[0])):
        img = torch.cat([frame2img(f[t], height, width) for f in it], dim=0)
        torchvision.utils.save_image(
            img.float() / 255.0,
            f"/tmp/frame_{t:03d}.png",
            nrow=8,
            padding=6,
            pad_value=0,
        )

    # m = (torch.rand(seq.size()) < 0.05).long()
    # seq = (1 - m) * seq + m * 23

    img = seq2img(seq, height, width)
    print(img.size())

    torchvision.utils.save_image(
        img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0
    )