Update.
[culture.git] / world.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, tqdm
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
17
18 colors = torch.tensor(
19     [
20         [255, 255, 255],
21         [0, 0, 255],
22         [0, 0, 255],
23         [0, 192, 0],
24         [0, 255, 0],
25         [0, 255, 127],
26         [0, 255, 255],
27         [0, 255, 255],
28         [30, 144, 255],
29         [64, 224, 208],
30         [65, 105, 225],
31         [75, 0, 130],
32         [106, 90, 205],
33         [128, 0, 128],
34         [135, 206, 235],
35         [192, 192, 192],
36         [220, 20, 60],
37         [250, 128, 114],
38         [255, 0, 0],
39         [255, 0, 255],
40         [255, 105, 180],
41         [255, 127, 80],
42         [255, 165, 0],
43         [255, 182, 193],
44         [255, 20, 147],
45         [255, 200, 0],
46     ]
47 )
48
49 token_background = 0
50 first_bird_token = 1
51 nb_bird_tokens = colors.size(0) - 1
52 token_forward = first_bird_token + nb_bird_tokens
53 token_backward = token_forward + 1
54
55 token2char = "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
56
57
58 def generate(
59     nb,
60     height,
61     width,
62     nb_birds=2,
63     nb_iterations=2,
64 ):
65     pairs = []
66
67     for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
68         f_start = torch.zeros(height, width, dtype=torch.int64)
69         f_end = torch.zeros(height, width, dtype=torch.int64)
70         n = torch.arange(f_start.size(0))
71
72         for c in (
73             (torch.randperm(nb_bird_tokens) + first_bird_token)[:nb_birds].sort().values
74         ):
75             i, j = (
76                 torch.randint(height - 2, (1,))[0] + 1,
77                 torch.randint(width - 2, (1,))[0] + 1,
78             )
79             vm = torch.randint(4, (1,))[0]
80             vi, vj = (vm // 2) * (2 * (vm % 2) - 1), (1 - vm // 2) * (2 * (vm % 2) - 1)
81
82             f_start[i, j] = c
83             f_start[i - vi, j - vj] = c
84             f_start[i + vj, j - vi] = c
85             f_start[i - vj, j + vi] = c
86
87             for l in range(nb_iterations):
88                 i += vi
89                 j += vj
90                 if i < 0 or i >= height or j < 0 or j >= width:
91                     i -= vi
92                     j -= vj
93                     vi, vj = -vi, -vj
94                     i += vi
95                     j += vj
96
97             f_end[i, j] = c
98             f_end[i - vi, j - vj] = c
99             f_end[i + vj, j - vi] = c
100             f_end[i - vj, j + vi] = c
101
102         pairs.append((f_start, f_end))
103
104     result = []
105     for p in pairs:
106         if torch.rand(1) < 0.5:
107             result.append(
108                 torch.cat(
109                     [p[0].flatten(), torch.tensor([token_forward]), p[1].flatten()],
110                     dim=0,
111                 )[None, :]
112             )
113         else:
114             result.append(
115                 torch.cat(
116                     [p[1].flatten(), torch.tensor([token_backward]), p[0].flatten()],
117                     dim=0,
118                 )[None, :]
119             )
120
121     return torch.cat(result, dim=0)
122
123
124 def sample2img(seq, height, width, upscale=15):
125     f_first = seq[:, : height * width].reshape(-1, height, width)
126     f_second = seq[:, height * width + 1 :].reshape(-1, height, width)
127     direction = seq[:, height * width]
128
129     def mosaic(x, upscale):
130         x = x.reshape(-1, height, width)
131         m = torch.logical_and(x >= 0, x < first_bird_token + nb_bird_tokens).long()
132         x = colors[x * m].permute(0, 3, 1, 2)
133         s = x.shape
134         x = x[:, :, :, None, :, None].expand(-1, -1, -1, upscale, -1, upscale)
135         x = x.reshape(s[0], s[1], s[2] * upscale, s[3] * upscale)
136
137         x[:, :, :, torch.arange(0, x.size(3), upscale)] = 0
138         x[:, :, torch.arange(0, x.size(2), upscale), :] = 0
139         x = x[:, :, 1:, 1:]
140
141         for n in range(m.size(0)):
142             for i in range(m.size(1)):
143                 for j in range(m.size(2)):
144                     if m[n, i, j] == 0:
145                         for k in range(2, upscale - 2):
146                             x[n, :, i * upscale + k, j * upscale + k] = 0
147                             x[n, :, i * upscale + upscale - 1 - k, j * upscale + k] = 0
148
149         return x
150
151     direction_symbol = torch.full((direction.size(0), height * upscale - 1, upscale), 0)
152     direction_symbol = colors[direction_symbol].permute(0, 3, 1, 2)
153     separator = torch.full((direction.size(0), 3, height * upscale - 1, 1), 0)
154
155     for n in range(direction_symbol.size(0)):
156         if direction[n] == token_forward:
157             for k in range(upscale):
158                 direction_symbol[
159                     n,
160                     :,
161                     (height * upscale) // 2 - upscale // 2 + k,
162                     3 + abs(k - upscale // 2),
163                 ] = 0
164         elif direction[n] == token_backward:
165             for k in range(upscale):
166                 direction_symbol[
167                     n,
168                     :,
169                     (height * upscale) // 2 - upscale // 2 + k,
170                     3 + upscale // 2 - abs(k - upscale // 2),
171                 ] = 0
172         else:
173             for k in range(2, upscale - 2):
174                 direction_symbol[
175                     n, :, (height * upscale) // 2 - upscale // 2 + k, k
176                 ] = 0
177                 direction_symbol[
178                     n, :, (height * upscale) // 2 - upscale // 2 + k, upscale - 1 - k
179                 ] = 0
180
181     return torch.cat(
182         [
183             mosaic(f_first, upscale),
184             separator,
185             direction_symbol,
186             separator,
187             mosaic(f_second, upscale),
188         ],
189         dim=3,
190     )
191
192
193 def seq2str(seq):
194     result = []
195     for s in seq:
196         result.append("".join([token2char[v] for v in s]))
197     return result
198
199
200 ######################################################################
201
202 if __name__ == "__main__":
203     import time
204
205     height, width = 6, 8
206     start_time = time.perf_counter()
207     seq = generate(nb=90, height=height, width=width)
208     delay = time.perf_counter() - start_time
209     print(f"{seq.size(0)/delay:02f} samples/s")
210
211     print(seq2str(seq[:4]))
212
213     # m = (torch.rand(seq.size()) < 0.05).long()
214     # seq = (1 - m) * seq + m * 23
215
216     img = sample2img(seq, height, width)
217     print(img.size())
218
219     torchvision.utils.save_image(
220         img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0
221     )