68f46de76b1bea49fd25936696a604d000d48600
[culture.git] / world.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, tqdm
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
17
# RGB palette indexed by token value: entry 0 is the background colour and
# every further entry is the colour of one bird token.
colors = torch.tensor(
    [
        [255, 255, 255],  # background (white)
        [255, 20, 147],
        [0, 0, 255],
        [0, 192, 0],
        [0, 255, 255],
        [192, 192, 192],
        [106, 90, 205],
        [255, 0, 0],
        [220, 20, 60],
        [65, 105, 225],
        [255, 200, 0],
        # Spare colours, currently disabled:
        # [255, 182, 193], [75, 0, 130], [128, 0, 128], [30, 144, 255],
        # [135, 206, 235], [0, 255, 0], [64, 224, 208], [250, 128, 114],
        # [255, 165, 0], [0, 255, 255],
    ]
)

# Token vocabulary: the background, one token per bird colour, then the two
# direction markers.
token_background = 0
first_bird_token = token_background + 1
nb_bird_tokens = len(colors) - 1
token_forward = first_bird_token + nb_bird_tokens
token_backward = token_forward + 1

# One display character per token, in token order: "_" for the background,
# "A", "B", ... for the birds, then ">" and "<" for the direction markers.
token2char = (
    "_"
    + "".join(chr(ord("A") + k) for k in range(nb_bird_tokens))
    + "><"
)
51
52
def generate(
    nb,
    height,
    width,
    nb_birds=3,
    nb_iterations=1,
):
    """Generate ``nb`` random bird-world samples as token sequences.

    Every sample is built from a pair of ``height x width`` token grids. The
    start grid holds ``nb_birds`` birds with distinct colour tokens. Each bird
    is an L-shaped group of three cells: a head at ``(i, j)`` plus the cells
    ``(i - vi, j)`` and ``(i, j - vj)`` behind it, where ``(vi, vj)`` is a
    diagonal velocity with both components in ``{-1, 1}``.

    The end grid is obtained by moving the birds, one at a time, for
    ``nb_iterations`` steps. A bird reverses a velocity component when its
    head sits on the matching border and is heading out. It keeps its
    previous position and velocity if the move would overlap another bird.

    Each pair is serialised as one row, chosen with probability 1/2 as
    ``start, token_forward, end`` or ``end, token_backward, start``.

    Args:
        nb: number of samples.
        height, width: grid size.
        nb_birds: birds per sample. Each bird gets its own colour, so this is
            at most ``nb_bird_tokens``.
        nb_iterations: number of movement steps between the two grids.

    Returns:
        An int64 tensor of shape ``(nb, 2 * height * width + 1)``. It has
        zero rows when ``nb == 0``.

    Raises:
        ValueError: if ``nb_birds`` exceeds the number of bird colours, or if
            birds are requested on a grid with fewer than two rows or
            columns (no bird can fit there).

    NOTE(review): birds are placed by rejection sampling with no retry limit,
    so this can still hang if the requested birds cannot all fit on the grid.
    """
    if nb_birds > nb_bird_tokens:
        raise ValueError(
            f"nb_birds={nb_birds} exceeds the {nb_bird_tokens} available bird colours"
        )
    if nb_birds > 0 and (height < 2 or width < 2):
        # A bird spans two rows and two columns; without this check the
        # placement loop below would never terminate.
        raise ValueError(f"a {height}x{width} grid is too small to hold a bird")

    pairs = []

    for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
        f_start = torch.zeros(height, width, dtype=torch.int64)

        # Per-bird head position (i, j) and diagonal velocity (vi, vj).
        i, j, vi, vj = (torch.empty(nb_birds, dtype=torch.int64) for _ in range(4))

        # Distinct colour tokens, sorted so the birds are listed in token order.
        col = (
            torch.randperm(nb_bird_tokens)[:nb_birds].sort().values + first_bird_token
        )

        for n in range(nb_birds):
            c = col[n]

            # Rejection-sample a head and a velocity until the whole bird lies
            # inside the grid on empty cells.
            while True:
                i[n], j[n] = (
                    torch.randint(height, (1,))[0],
                    torch.randint(width, (1,))[0],
                )
                vm = torch.randint(4, (1,))[0]
                vi[n], vj[n] = (vm % 2) * 2 - 1, (vm // 2) * 2 - 1
                if (
                    0 <= i[n] - vi[n] < height
                    and 0 <= j[n] - vj[n] < width
                    and f_start[i[n], j[n]] == 0
                    and f_start[i[n] - vi[n], j[n]] == 0
                    and f_start[i[n], j[n] - vj[n]] == 0
                ):
                    break

            f_start[i[n], j[n]] = c
            f_start[i[n] - vi[n], j[n]] = c
            f_start[i[n], j[n] - vj[n]] = c

        f_end = f_start.clone()

        for _ in range(nb_iterations):
            for n in range(nb_birds):
                c = col[n]

                # Lift the bird off the grid so it cannot collide with itself.
                f_end[i[n], j[n]] = 0
                f_end[i[n] - vi[n], j[n]] = 0
                f_end[i[n], j[n] - vj[n]] = 0

                pi, pj, pvi, pvj = i[n].item(), j[n].item(), vi[n].item(), vj[n].item()

                # Bounce: reverse any velocity component whose step would take
                # the head out of the grid.
                if (i[n] == 0 and vi[n] == -1) or (i[n] == height - 1 and vi[n] == 1):
                    vi[n] = -vi[n]
                if (j[n] == 0 and vj[n] == -1) or (j[n] == width - 1 and vj[n] == 1):
                    vj[n] = -vj[n]

                i[n] += vi[n]
                j[n] += vj[n]

                # Blocked by another bird: stay put with the previous velocity.
                if not (
                    f_end[i[n], j[n]] == 0
                    and f_end[i[n] - vi[n], j[n]] == 0
                    and f_end[i[n], j[n] - vj[n]] == 0
                ):
                    i[n], j[n], vi[n], vj[n] = pi, pj, pvi, pvj

                f_end[i[n], j[n]] = c
                f_end[i[n] - vi[n], j[n]] = c
                f_end[i[n], j[n] - vj[n]] = c

        pairs.append((f_start, f_end))

    if not pairs:
        # torch.cat / torch.stack reject an empty list: return a well-shaped
        # empty batch instead of crashing.
        return torch.empty(0, 2 * height * width + 1, dtype=torch.int64)

    # The direction coin flips happen after all grids are drawn, as before, so
    # seeded runs produce the same samples.
    result = []
    for start, end in pairs:
        if torch.rand(1) < 0.5:
            first, marker, second = start, token_forward, end
        else:
            first, marker, second = end, token_backward, start
        result.append(
            torch.cat([first.flatten(), torch.tensor([marker]), second.flatten()])
        )

    return torch.stack(result)
155
156
def generate_(
    nb,
    height,
    width,
    nb_birds=3,
    nb_iterations=2,
):
    """Alternative generator with 4-cell T-shaped birds moving along an axis.

    Each bird has a head at ``(i, j)``, a tail cell ``(i - vi, j - vj)`` and
    two side cells ``(i + vj, j - vi)`` and ``(i - vj, j + vi)``. Here
    ``(vi, vj)`` is one of the four unit axis directions. Heads are drawn
    at least one cell away from the border so that the start bird fits. The
    bird then moves ``nb_iterations`` steps. When the head would leave the
    grid, the step is undone, the direction is reversed and the bird steps
    the other way.

    Each (start, end) pair is serialised as one row, chosen with probability
    1/2 as ``start, token_forward, end`` or ``end, token_backward, start``.

    Args:
        nb: number of samples.
        height, width: grid size.
        nb_birds: birds per sample.
        nb_iterations: number of movement steps between the two grids.

    Returns:
        An int64 tensor of shape ``(nb, 2 * height * width + 1)``. It has
        zero rows when ``nb == 0``.

    NOTE(review): birds are not checked for overlap, so a later bird can
    overwrite cells of an earlier one. At most ``nb_bird_tokens`` birds are
    drawn, even when ``nb_birds`` is larger.
    """
    pairs = []

    for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
        f_start = torch.zeros(height, width, dtype=torch.int64)
        f_end = torch.zeros(height, width, dtype=torch.int64)

        for c in (
            (torch.randperm(nb_bird_tokens) + first_bird_token)[:nb_birds].sort().values
        ):
            i, j = (
                torch.randint(height - 2, (1,))[0] + 1,
                torch.randint(width - 2, (1,))[0] + 1,
            )
            # vm in {0, 1, 2, 3} -> (vi, vj) in {(0, -1), (0, 1), (-1, 0), (1, 0)}
            vm = torch.randint(4, (1,))[0]
            vi, vj = (vm // 2) * (2 * (vm % 2) - 1), (1 - vm // 2) * (2 * (vm % 2) - 1)

            f_start[i, j] = c
            f_start[i - vi, j - vj] = c
            f_start[i + vj, j - vi] = c
            f_start[i - vj, j + vi] = c

            for _ in range(nb_iterations):
                i += vi
                j += vj
                if i < 0 or i >= height or j < 0 or j >= width:
                    # Undo the step, reverse, and step the other way instead.
                    i -= vi
                    j -= vj
                    vi, vj = -vi, -vj
                    i += vi
                    j += vj

            f_end[i, j] = c
            f_end[i - vi, j - vj] = c
            f_end[i + vj, j - vi] = c
            f_end[i - vj, j + vi] = c

        pairs.append((f_start, f_end))

    if not pairs:
        # torch.cat / torch.stack reject an empty list: return a well-shaped
        # empty batch instead of crashing.
        return torch.empty(0, 2 * height * width + 1, dtype=torch.int64)

    result = []
    for start, end in pairs:
        if torch.rand(1) < 0.5:
            first, marker, second = start, token_forward, end
        else:
            first, marker, second = end, token_backward, start
        result.append(
            torch.cat([first.flatten(), torch.tensor([marker]), second.flatten()])
        )

    return torch.stack(result)
221
222
def sample2img(seq, height, width, upscale=15):
    """Render token sequences as RGB images for visual inspection.

    Each row of ``seq`` is split into a first grid (``height * width``
    tokens), a direction token, and a second grid. The images lay out, left
    to right: the first grid, a 1-pixel black separator, a direction panel,
    another separator, and the second grid.

    Grid cells are drawn as ``upscale``-pixel blocks of their token colour,
    separated by black grid lines. Tokens outside the colour range are drawn
    in the background colour with a black cross. The direction panel shows a
    ">"-like chevron for ``token_forward``, a "<"-like chevron for
    ``token_backward``, and a cross for anything else.

    Args:
        seq: integer tensor with ``2 * height * width + 1`` tokens per row.
        height, width: grid size.
        upscale: pixels per grid cell.

    Returns:
        An int64 tensor of shape ``(N, 3, height * upscale - 1, W)`` with values
        in ``[0, 255]``, where ``W = 2 * (width * upscale - 1) + upscale + 2``.
    """
    f_first = seq[:, : height * width].reshape(-1, height, width)
    f_second = seq[:, height * width + 1 :].reshape(-1, height, width)
    direction = seq[:, height * width]

    def mosaic(x, upscale):
        x = x.reshape(-1, height, width)
        # 1 for tokens that have a colour, 0 for anything else.
        m = torch.logical_and(x >= 0, x < first_bird_token + nb_bird_tokens).long()
        # Invalid tokens are looked up as token 0 (the background colour).
        x = colors[x * m].permute(0, 3, 1, 2)
        s = x.shape
        x = x[:, :, :, None, :, None].expand(-1, -1, -1, upscale, -1, upscale)
        x = x.reshape(s[0], s[1], s[2] * upscale, s[3] * upscale)

        # Black grid lines on the first row/column of every cell, then drop
        # the leading line so the image starts inside the first cell.
        x[:, :, :, torch.arange(0, x.size(3), upscale)] = 0
        x[:, :, torch.arange(0, x.size(2), upscale), :] = 0
        x = x[:, :, 1:, 1:]

        # Cross out cells holding invalid tokens; only those cells are visited
        # instead of looping over the whole grid in Python.
        for n, i, j in (m == 0).nonzero().tolist():
            for k in range(2, upscale - 2):
                x[n, :, i * upscale + k, j * upscale + k] = 0
                x[n, :, i * upscale + upscale - 1 - k, j * upscale + k] = 0

        return x

    # Background-coloured panel (token 0) on which the direction symbol is drawn.
    direction_symbol = torch.full((direction.size(0), height * upscale - 1, upscale), 0)
    direction_symbol = colors[direction_symbol].permute(0, 3, 1, 2)
    separator = torch.full((direction.size(0), 3, height * upscale - 1, 1), 0)

    # Top row of the symbol, vertically centred in the panel.
    y0 = (height * upscale) // 2 - upscale // 2
    # Left offset of the chevrons: 3 px as before, but reduced for small
    # upscale values so the chevron tip stays inside the panel (the fixed
    # offset raised IndexError for upscale <= 6).
    x0 = min(3, upscale - 1 - upscale // 2)

    for n in range(direction_symbol.size(0)):
        if direction[n] == token_forward:
            for k in range(upscale):
                direction_symbol[
                    n, :, y0 + k, x0 + upscale // 2 - abs(k - upscale // 2)
                ] = 0
        elif direction[n] == token_backward:
            for k in range(upscale):
                direction_symbol[n, :, y0 + k, x0 + abs(k - upscale // 2)] = 0
        else:
            for k in range(2, upscale - 2):
                direction_symbol[n, :, y0 + k, k] = 0
                direction_symbol[n, :, y0 + k, upscale - 1 - k] = 0

    return torch.cat(
        [
            mosaic(f_first, upscale),
            separator,
            direction_symbol,
            separator,
            mosaic(f_second, upscale),
        ],
        dim=3,
    )
290
291
def seq2str(seq):
    """Render each token sequence in ``seq`` as a string, using ``token2char``.

    Tokens outside ``[0, len(token2char))`` are shown as ``"?"``. Previously,
    values that were too large raised IndexError and negative values silently
    wrapped around through Python's negative indexing.

    Args:
        seq: iterable of token sequences, e.g. a 2-D integer tensor.

    Returns:
        A list with one string per sequence.
    """
    nb_chars = len(token2char)

    def char_of(token):
        t = int(token)
        return token2char[t] if 0 <= t < nb_chars else "?"

    return ["".join(char_of(v) for v in s) for s in seq]
297
298
299 ######################################################################
300
if __name__ == "__main__":
    import time

    height, width = 6, 8

    # Time the generation of a small batch.
    start_time = time.perf_counter()
    seq = generate(nb=90, height=height, width=width)
    delay = time.perf_counter() - start_time
    # ".2f" = two decimals (the previous "02f" meant width 2, six decimals).
    print(f"{seq.size(0)/delay:.2f} samples/s")

    print(seq2str(seq[:4]))

    # Uncomment to replace ~5% of the tokens with the out-of-range value 23:
    # m = (torch.rand(seq.size()) < 0.05).long()
    # seq = (1 - m) * seq + m * 23

    img = sample2img(seq, height, width)
    print(img.size())

    torchvision.utils.save_image(
        img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0
    )
319         img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0
320     )