Update.
[culture.git] / world.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, tqdm
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
17
# Palette indexed by token value (RGB, 0-255). Row 0 is the background
# colour; every following row is a colour a bird can take.
colors = torch.tensor(
    [
        (255, 255, 255),  # token 0: background (white)
        (255, 0, 0),
        (0, 192, 0),
        (0, 0, 255),
        (255, 192, 0),
        (0, 255, 255),
        (255, 0, 255),
        (192, 255, 192),
        (255, 192, 192),
        (192, 192, 255),
        (192, 192, 192),
    ]
)

# Token vocabulary:
#   token_background                  -> empty cell
#   first_bird_token .. +nb-1         -> one token per bird colour
#   token_forward / token_backward    -> separator placed between two grids
token_background = 0
first_bird_token = token_background + 1
nb_bird_tokens = colors.shape[0] - first_bird_token
token_forward = nb_bird_tokens + first_bird_token
token_backward = 1 + token_forward

# One printable character per token: "_" for background, "A", "B", ... for
# the bird colours, then ">" and "<" for the two separator tokens.
token2char = (
    "_"
    + "".join(map(chr, range(ord("A"), ord("A") + nb_bird_tokens)))
    + "><"
)
41
42
def generate_seq(
    nb,
    height,
    width,
    nb_birds=3,
    nb_iterations=2,
):
    """Generate `nb` random "bird world" samples encoded as token sequences.

    A sample is a pair of (height, width) int64 grids.

    The start grid holds `nb_birds` non-overlapping birds of distinct
    colours. Each bird is three cells:
    - a head (i, j);
    - the two cells one step behind it along each axis, (i - vi, j) and
      (i, j - vj).
    Its diagonal velocity has vi, vj in {-1, 1}.

    The end grid is obtained by moving every bird `nb_iterations` steps. A
    velocity component is reversed whenever the head would otherwise leave
    the grid. Only the final grid is checked for overlapping birds; a sample
    where that grid has overlaps is redrawn from scratch.

    Each pair is serialized as one row of 2 * height * width + 1 tokens:
    - with probability 1/2: `start, token_forward, end`;
    - otherwise: `end, token_backward, start`.

    Args:
        nb: number of samples to generate (0 gives an empty tensor).
        height, width: grid size.
        nb_birds: birds per grid; must not exceed colors.size(0) - 1.
        nb_iterations: movement steps between the two grids (0 is allowed).

    Returns:
        int64 tensor of shape (nb, 2 * height * width + 1).

    NOTE: the rejection loops never terminate if the grid is too small to
    hold `nb_birds` non-overlapping birds.
    """
    pairs = []

    for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
        # Rejection sampling: redraw the whole world until the final grid
        # contains no overlapping birds.
        while True:
            f_start = torch.zeros(height, width, dtype=torch.int64)

            # Per-bird head position (i, j) and velocity (vi, vj).
            i, j, vi, vj = (
                torch.empty(nb_birds, dtype=torch.int64),
                torch.empty(nb_birds, dtype=torch.int64),
                torch.empty(nb_birds, dtype=torch.int64),
                torch.empty(nb_birds, dtype=torch.int64),
            )

            # nb_birds distinct colour tokens in 1 .. colors.size(0) - 1.
            col = torch.randperm(colors.size(0) - 1)[:nb_birds].sort().values + 1

            for n in range(nb_birds):
                c = col[n]

                # Draw a head and a diagonal direction until all three cells
                # of the bird are inside the grid and still empty.
                while True:
                    i[n], j[n] = (
                        torch.randint(height, (1,))[0],
                        torch.randint(width, (1,))[0],
                    )
                    vm = torch.randint(4, (1,))[0]
                    vi[n], vj[n] = (vm % 2) * 2 - 1, (vm // 2) * 2 - 1
                    if (
                        i[n] - vi[n] >= 0
                        and i[n] - vi[n] < height
                        and j[n] - vj[n] >= 0
                        and j[n] - vj[n] < width
                        and f_start[i[n], j[n]] == 0
                        and f_start[i[n] - vi[n], j[n]] == 0
                        and f_start[i[n], j[n] - vj[n]] == 0
                    ):
                        break

                f_start[i[n], j[n]] = c
                f_start[i[n] - vi[n], j[n]] = c
                f_start[i[n], j[n] - vj[n]] = c

            f_end = f_start.clone()

            # Bound before the loop so nb_iterations == 0 works: the start
            # grid is overlap-free by construction.
            nb_collisions = 0

            for _ in range(nb_iterations):
                f_end[...] = 0
                # Reset at every step, so only the final frame is checked.
                nb_collisions = 0
                for n in range(nb_birds):
                    c = col[n]

                    # Bounce: flip a velocity component pointing off the grid.
                    if (i[n] == 0 and vi[n] == -1) or (
                        i[n] == height - 1 and vi[n] == 1
                    ):
                        vi[n] = -vi[n]
                    if (j[n] == 0 and vj[n] == -1) or (
                        j[n] == width - 1 and vj[n] == 1
                    ):
                        vj[n] = -vj[n]

                    i[n] += vi[n]
                    j[n] += vj[n]

                    # A cell already painted in this frame means two birds
                    # overlap.
                    if not (
                        f_end[i[n], j[n]] == 0
                        and f_end[i[n] - vi[n], j[n]] == 0
                        and f_end[i[n], j[n] - vj[n]] == 0
                    ):
                        nb_collisions += 1

                    f_end[i[n], j[n]] = c
                    f_end[i[n] - vi[n], j[n]] = c
                    f_end[i[n], j[n] - vj[n]] = c

            if nb_collisions == 0:
                break

        pairs.append((f_start, f_end))

    # torch.cat raises on an empty list, so handle nb == 0 explicitly.
    if not pairs:
        return torch.empty(0, 2 * height * width + 1, dtype=torch.int64)

    rows = []
    for start, end in pairs:
        # Randomly present the pair forward (start -> end) or backward.
        if torch.rand(1) < 0.5:
            parts = [start.flatten(), torch.tensor([token_forward]), end.flatten()]
        else:
            parts = [end.flatten(), torch.tensor([token_backward]), start.flatten()]
        rows.append(torch.cat(parts, dim=0)[None, :])

    return torch.cat(rows, dim=0)
151
152
153 ######################################################################
154
155
def generate_seq_(
    nb,
    height,
    width,
    nb_birds=3,
    nb_iterations=2,
):
    """Alternative generator with four-cell birds moving along the axes.

    Each bird has four cells:
    - a centre (i, j);
    - a tail one step behind it, (i - vi, j - vj);
    - two wings perpendicular to the motion, (i + vj, j - vi) and
      (i - vj, j + vi).

    (vi, vj) is one of the four axis directions. The centre is drawn strictly
    inside the grid, so height and width must be at least 3. Each iteration
    moves the centre one cell. When the next step would take the centre off
    the grid, the bird reverses and steps the other way.

    Birds are NOT checked for overlaps: a later bird simply overwrites the
    cells of an earlier one.

    Each pair is serialized as one row of 2 * height * width + 1 tokens:
    - with probability 1/2: `start, token_forward, end`;
    - otherwise: `end, token_backward, start`.

    Returns:
        int64 tensor of shape (nb, 2 * height * width + 1); nb == 0 gives an
        empty tensor.
    """
    pairs = []

    for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
        f_start = torch.zeros(height, width, dtype=torch.int64)
        f_end = torch.zeros(height, width, dtype=torch.int64)

        # nb_birds distinct colour tokens, in increasing order.
        bird_colors = (
            (torch.randperm(nb_bird_tokens) + first_bird_token)[:nb_birds].sort().values
        )

        for c in bird_colors:
            # Centre strictly inside the grid so the tail and wings fit.
            i, j = (
                torch.randint(height - 2, (1,))[0] + 1,
                torch.randint(width - 2, (1,))[0] + 1,
            )
            # vm -> (vi, vj): 0 -> (0, -1), 1 -> (0, 1), 2 -> (-1, 0), 3 -> (1, 0)
            vm = torch.randint(4, (1,))[0]
            vi, vj = (vm // 2) * (2 * (vm % 2) - 1), (1 - vm // 2) * (2 * (vm % 2) - 1)

            f_start[i, j] = c
            f_start[i - vi, j - vj] = c
            f_start[i + vj, j - vi] = c
            f_start[i - vj, j + vi] = c

            for _ in range(nb_iterations):
                i += vi
                j += vj
                if i < 0 or i >= height or j < 0 or j >= width:
                    # Undo the step, reverse direction and step the other way.
                    i -= vi
                    j -= vj
                    vi, vj = -vi, -vj
                    i += vi
                    j += vj

            f_end[i, j] = c
            f_end[i - vi, j - vj] = c
            f_end[i + vj, j - vi] = c
            f_end[i - vj, j + vi] = c

        pairs.append((f_start, f_end))

    # torch.cat raises on an empty list, so handle nb == 0 explicitly.
    if not pairs:
        return torch.empty(0, 2 * height * width + 1, dtype=torch.int64)

    rows = []
    for start, end in pairs:
        # Randomly present the pair forward (start -> end) or backward.
        if torch.rand(1) < 0.5:
            parts = [start.flatten(), torch.tensor([token_forward]), end.flatten()]
        else:
            parts = [end.flatten(), torch.tensor([token_backward]), start.flatten()]
        rows.append(torch.cat(parts, dim=0)[None, :])

    return torch.cat(rows, dim=0)
220
221
def sample2img(seq, height, width, upscale=15):
    """Render encoded sequences as RGB images for visual inspection.

    Args:
        seq: integer tensor of shape (N, 2 * height * width + 1). Each row is
            first grid, direction token, second grid.
        height, width: grid size.
        upscale: side length in pixels of one grid cell.

    Each row becomes the two grids drawn side by side, with black grid lines
    and `upscale`-pixel cells. Between them is a panel showing:
    - ">" for token_forward;
    - "<" for token_backward;
    - a cross for any other token.
    Grid cells whose token is not a valid index into `colors` are drawn with
    colour 0 and crossed out.

    Returns:
        int64 tensor of shape
        (N, 3, height * upscale - 1, 2 * width * upscale + upscale)
        with values in 0..255.
    """
    f_first = seq[:, : height * width].reshape(-1, height, width)
    f_second = seq[:, height * width + 1 :].reshape(-1, height, width)
    direction = seq[:, height * width]

    def mosaic(x, upscale):
        # Draw a batch of (height, width) token grids as upscaled RGB images.
        x = x.reshape(-1, height, width)
        # m == 1 where the token is a valid colour index; invalid tokens are
        # drawn with colour 0 and crossed out below.
        m = torch.logical_and(x >= 0, x < first_bird_token + nb_bird_tokens).long()
        x = colors[x * m].permute(0, 3, 1, 2)
        s = x.shape
        # Nearest-neighbour upscaling: each cell becomes upscale x upscale pixels.
        x = x[:, :, :, None, :, None].expand(-1, -1, -1, upscale, -1, upscale)
        x = x.reshape(s[0], s[1], s[2] * upscale, s[3] * upscale)

        # Black grid lines on the first pixel row/column of every cell, then
        # drop the outer top/left line.
        x[:, :, :, torch.arange(0, x.size(3), upscale)] = 0
        x[:, :, torch.arange(0, x.size(2), upscale), :] = 0
        x = x[:, :, 1:, 1:]

        for n in range(m.size(0)):
            for i in range(m.size(1)):
                for j in range(m.size(2)):
                    if m[n, i, j] == 0:
                        for k in range(2, upscale - 2):
                            x[n, :, i * upscale + k, j * upscale + k] = 0
                            x[n, :, i * upscale + upscale - 1 - k, j * upscale + k] = 0

        return x

    # Column offset of the ">" / "<" glyphs inside the upscale-wide middle
    # panel. The glyph tip sits at column offset + upscale // 2. The offset
    # stays 3 whenever that fits; for small upscale values it is reduced just
    # enough to keep the tip inside the panel instead of raising IndexError.
    arrow_offset = min(3, upscale - 1 - upscale // 2)
    # First pixel row of the glyphs (roughly vertically centred).
    glyph_top = (height * upscale) // 2 - upscale // 2

    direction_symbol = torch.full((direction.size(0), height * upscale - 1, upscale), 0)
    direction_symbol = colors[direction_symbol].permute(0, 3, 1, 2)
    separator = torch.full((direction.size(0), 3, height * upscale - 1, 1), 0)

    for n in range(direction_symbol.size(0)):
        if direction[n] == token_forward:
            # ">" : column grows towards the middle row, then shrinks.
            for k in range(upscale):
                direction_symbol[
                    n,
                    :,
                    glyph_top + k,
                    arrow_offset + upscale // 2 - abs(k - upscale // 2),
                ] = 0
        elif direction[n] == token_backward:
            # "<" : column shrinks towards the middle row, then grows.
            for k in range(upscale):
                direction_symbol[
                    n,
                    :,
                    glyph_top + k,
                    arrow_offset + abs(k - upscale // 2),
                ] = 0
        else:
            # Unknown direction token: draw a cross.
            for k in range(2, upscale - 2):
                direction_symbol[n, :, glyph_top + k, k] = 0
                direction_symbol[n, :, glyph_top + k, upscale - 1 - k] = 0

    return torch.cat(
        [
            mosaic(f_first, upscale),
            separator,
            direction_symbol,
            separator,
            mosaic(f_second, upscale),
        ],
        dim=3,
    )
289
290
def seq2str(seq):
    """Render each row of a token sequence as a string, one character per token.

    Tokens are mapped through `token2char`. A token outside
    0 .. len(token2char) - 1 (e.g. an injected noise token) is shown as "?".
    Such tokens would otherwise raise IndexError, or, for negative values,
    silently wrap around to the end of `token2char`.

    Args:
        seq: iterable of rows, each an iterable of integer tokens (e.g. a 2-D
            integer tensor).

    Returns:
        list with one string per row of `seq`.
    """
    nb_chars = len(token2char)
    return [
        "".join(token2char[v] if 0 <= v < nb_chars else "?" for v in map(int, row))
        for row in seq
    ]
296
297
298 ######################################################################
299
if __name__ == "__main__":
    # Smoke test: time the generator, print a few sequences as text and save
    # a rendering of all of them to /tmp/world.png.
    import time

    height, width = 6, 8
    start_time = time.perf_counter()
    seq = generate_seq(nb=90, height=height, width=width)
    delay = time.perf_counter() - start_time
    # ".2f" = two decimal places ("02f" would mean zero-padded width 2 with
    # the default six decimals).
    print(f"{seq.size(0)/delay:.2f} samples/s")

    print(seq2str(seq[:4]))

    # Optional experiment: replace ~5% of the tokens with the
    # out-of-vocabulary token 23, presumably to inspect how invalid tokens
    # are rendered.
    # m = (torch.rand(seq.size()) < 0.05).long()
    # seq = (1 - m) * seq + m * 23

    img = sample2img(seq, height, width)
    print(img.size())

    torchvision.utils.save_image(
        img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0
    )