Update.
[culture.git] / wireworld.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, tqdm, os
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
17 import problem
18
19
class Wireworld(problem.Problem):
    """Wireworld cellular-automaton problem.

    A frame is a ``height x width`` grid of cell tokens (empty, head, tail,
    conductor). A frame sequence holds ``nb_iterations`` snapshots of the
    automaton, taken every ``speed`` simulation steps. A token sequence is the
    concatenation of the flattened frames of one frame sequence, separated by
    a direction token telling whether the frames are listed forward or
    backward in time.
    """

    # RGB color of each cell token, indexed by token value
    # (empty, head, tail, conductor).
    colors = torch.tensor(
        [
            [128, 128, 128],
            [128, 128, 255],
            [255, 0, 0],
            [255, 255, 0],
        ]
    )

    token_empty = 0
    token_head = 1
    token_tail = 2
    token_conductor = 3
    token_forward = 4
    token_backward = 5

    # One display character per token value: "_" for empty, letters for the
    # other cell tokens, then ">" / "<" for the two direction tokens.
    token2char = (
        "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
    )

    def __init__(
        self, height=6, width=8, nb_objects=2, nb_walls=2, speed=1, nb_iterations=4
    ):
        """
        Args:
            height, width: size of every frame, in cells.
            nb_objects, nb_walls: stored as attributes; not used by the
                methods of this class.
            speed: number of simulation steps between two recorded frames.
            nb_iterations: number of frames recorded per sequence.
        """
        self.height = height
        self.width = width
        self.nb_objects = nb_objects
        self.nb_walls = nb_walls
        self.speed = speed
        self.nb_iterations = nb_iterations

    def direction_tokens(self):
        """Return the (forward, backward) direction token values."""
        return self.token_forward, self.token_backward

    def generate_frame_sequences(self, nb):
        """Return ``nb`` frame sequences, shape (nb, nb_iterations, height, width).

        Sequences are produced in chunks of ``N`` by
        ``generate_frame_sequences_hard``, which always returns exactly the
        number it is asked for. ceil(nb / N) chunks are therefore enough; at
        least one chunk is generated so that ``nb <= 0`` still yields a
        correctly shaped result.
        """
        N = 100
        result = []
        for _ in tqdm.tqdm(
            range(0, max(nb, 1), N), dynamic_ncols=True, desc="world generation"
        ):
            result.append(self.generate_frame_sequences_hard(N))
        return torch.cat(result, dim=0)[:nb]

    def _random_cell_and_direction(self):
        """Draw a random cell (i, j) and a random unit move (vi, vj) along one axis."""
        i = torch.randint(self.height, (1,))
        j = torch.randint(self.width, (1,))
        v = torch.randint(2, (2,))
        # v[0] picks the axis (1: vertical, 0: horizontal), v[1] the sign.
        sign = v[1] * 2 - 1
        return i, j, v[0] * sign, (1 - v[0]) * sign

    def _draw_conductors(self, grid):
        """Draw random straight conductor wires into ``grid`` (modified in place).

        Each wire starts at a random cell and extends along a random axis
        direction until it leaves the grid or reaches a cell that has more
        than one conductor among its 4 neighbours. After each wire, once the
        grid holds more than ``width`` conductor cells, drawing stops with
        probability 1/2.
        """
        while True:
            i, j, vi, vj = self._random_cell_and_direction()
            while i >= 0 and i < self.height and j >= 0 and j < self.width:
                # Count conductors among the 4 neighbours of (i, j).
                nb_neighbors = 0
                if i > 0:
                    nb_neighbors += (grid[i - 1, j] == self.token_conductor).long()
                if i < self.height - 1:
                    nb_neighbors += (grid[i + 1, j] == self.token_conductor).long()
                if j > 0:
                    nb_neighbors += (grid[i, j - 1] == self.token_conductor).long()
                if j < self.width - 1:
                    nb_neighbors += (grid[i, j + 1] == self.token_conductor).long()
                # Stop before merging into another wire.
                if nb_neighbors > 1:
                    break
                grid[i, j] = self.token_conductor
                i += vi
                j += vj
            if (
                grid == self.token_conductor
            ).long().sum() > self.width and torch.rand(1) < 0.5:
                break

    def _place_electrons(self, grid):
        """Place electrons on the conductors of ``grid`` (modified in place).

        Each round tries up to height * width random (cell, direction) draws
        and, for the first draw where both the cell and its neighbour in that
        direction are conductors, turns the cell into a head and the
        neighbour into a tail. A round may place nothing. After each round,
        placement stops with probability 0.75.
        """
        while True:
            for _ in range(self.height * self.width):
                i, j, vi, vj = self._random_cell_and_direction()
                if (
                    i + vi >= 0
                    and i + vi < self.height
                    and j + vj >= 0
                    and j + vj < self.width
                    and grid[i, j] == self.token_conductor
                    and grid[i + vi, j + vj] == self.token_conductor
                ):
                    grid[i, j] = self.token_head
                    grid[i + vi, j + vj] = self.token_tail
                    break

            if torch.rand(1) < 0.75:
                break

    def _simulate(self, result):
        """Fill frames 1.. of ``result`` (N, T, H, W) in place from frame 0.

        Wireworld rule applied to go from frame l to frame l + 1:
        empty -> empty, head -> tail, tail -> conductor, and
        conductor -> head if exactly 1 or 2 of its 8 neighbours are heads,
        otherwise conductor.
        """
        # 3x3 all-ones kernel counting heads in the Moore neighbourhood. It
        # also counts the centre cell, which is harmless: the count is only
        # used where the centre is a conductor.
        weight = torch.full((1, 1, 3, 3), 1.0)

        for l in range(result.size(1) - 1):
            frame = result[:, l]
            nb_head_neighbors = (
                F.conv2d(
                    input=(frame == self.token_head).float()[:, None, :, :],
                    weight=weight,
                    padding=1,
                )
                .long()
                .squeeze(1)
            )
            mask_1_or_2_heads = (nb_head_neighbors == 1).long() + (
                nb_head_neighbors == 2
            ).long()
            result[:, l + 1] = (
                (frame == self.token_empty).long() * self.token_empty
                + (frame == self.token_head).long() * self.token_tail
                + (frame == self.token_tail).long() * self.token_conductor
                + (frame == self.token_conductor).long()
                * (
                    mask_1_or_2_heads * self.token_head
                    + (1 - mask_1_or_2_heads) * self.token_conductor
                )
            )

    def generate_frame_sequences_hard(self, nb):
        """Return exactly ``nb`` frame sequences whose last frame still has a head.

        Builds 4 * nb random circuits, simulates each for
        nb_iterations * speed steps, keeps one frame every ``speed`` steps,
        and discards sequences whose last kept frame contains no head.
        Missing sequences are generated through ``generate_frame_sequences``.

        Returns:
            A (nb, nb_iterations, height, width) integer token tensor.
        """
        result = torch.full(
            (nb * 4, self.nb_iterations * self.speed, self.height, self.width),
            self.token_empty,
        )

        # Initial frame of each candidate: wires first, then electrons.
        for n in range(result.size(0)):
            self._draw_conductors(result[n, 0])
            self._place_electrons(result[n, 0])

        self._simulate(result)

        # Subsample: keep one frame every `speed` simulation steps.
        result = result[
            :, torch.arange(self.nb_iterations, device=result.device) * self.speed
        ]

        # Keep only sequences that still have at least one head at the end.
        alive = (result[:, -1] == self.token_head).flatten(1).any(dim=1)
        result = result[alive]

        if result.size(0) < nb:
            result = torch.cat(
                [result, self.generate_frame_sequences(nb - result.size(0))], dim=0
            )

        return result[:nb]

    def generate_token_sequences(self, nb):
        """Return ``nb`` token sequences as a (nb, L) tensor, one row per sequence.

        Each row is the flattened frames of one frame sequence. With
        probability 1/2 the frames are listed in time order, separated by
        ``token_forward``; otherwise they are listed in reverse order,
        separated by ``token_backward``.
        """
        frame_sequences = self.generate_frame_sequences(nb)

        result = []

        for frame_sequence in frame_sequences:
            if torch.rand(1) < 0.5:
                frames, direction = frame_sequence, self.token_forward
            else:
                frames, direction = reversed(frame_sequence), self.token_backward

            a = []
            for frame in frames:
                if a:
                    a.append(torch.tensor([direction]))
                a.append(frame.flatten())

            result.append(torch.cat(a, dim=0)[None, :])

        return torch.cat(result, dim=0)

    ######################################################################

    def frame2img(self, x, scale=15):
        """Render frames as RGB images.

        Args:
            x: tokens reshapeable to (N, height, width).
            scale: size in pixels of one cell, including a one-pixel black
                grid line.

        Returns:
            An (N, 3, height * scale - 1, width * scale - 1) tensor of 0-255
            color values. Tokens outside [0, len(colors)) are drawn with
            color 0 and crossed out.
        """
        x = x.reshape(-1, self.height, self.width)
        # 1 where the token has a color, 0 otherwise.
        m = torch.logical_and(x >= 0, x < len(self.colors)).long()

        x = self.colors[x * m].permute(0, 3, 1, 2)
        s = x.shape
        # Blow up every cell into a scale x scale block of pixels.
        x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
        x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)

        # One-pixel black grid lines; the first row/column is then dropped.
        x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
        x[:, :, torch.arange(0, x.size(2), scale), :] = 0
        x = x[:, :, 1:, 1:]

        # Draw a cross over every cell whose token has no color.
        for n, i, j in (m == 0).nonzero().tolist():
            for k in range(2, scale - 2):
                for l in [0, 1]:
                    x[n, :, i * scale + k, j * scale + k - l] = 0
                    x[n, :, i * scale + scale - 1 - k, j * scale + k - l] = 0

        return x

    def seq2img(self, seq, scale=15):
        """Render token sequences as one horizontal strip per row.

        ``seq`` is (N, L). Each row is read as a frame of height * width
        tokens followed by repeated (direction token, frame) groups. Frames
        are drawn with ``frame2img``. Each direction token becomes a tile
        ``scale`` pixels wide, with one-pixel black separators on both sides.
        The tile shows ">" for token_forward, "<" for token_backward, and a
        cross for any other value.

        Returns:
            An (N, 3, height * scale - 1, W) tensor of 0-255 color values,
            with the pieces concatenated along the width.
        """
        pieces = [
            self.frame2img(
                seq[:, : self.height * self.width].reshape(-1, self.height, self.width),
                scale,
            )
        ]

        separator = torch.full((seq.size(0), 3, self.height * scale - 1, 1), 0)

        # Row of the glyph's first stroke, so the glyph is vertically centred.
        top = (self.height * scale) // 2 - scale // 2

        t = self.height * self.width

        while t < seq.size(1):
            direction_tokens = seq[:, t]
            t += 1

            # Tiles filled with the background color (color 0).
            direction_images = self.colors[
                torch.full(
                    (direction_tokens.size(0), self.height * scale - 1, scale), 0
                )
            ].permute(0, 3, 1, 2)

            for n in range(direction_tokens.size(0)):
                if direction_tokens[n] == self.token_forward:
                    # ">" shape.
                    for k in range(scale):
                        for l in [0, 1]:
                            direction_images[
                                n, :, top + k - l, 3 + scale // 2 - abs(k - scale // 2)
                            ] = 0
                elif direction_tokens[n] == self.token_backward:
                    # "<" shape.
                    for k in range(scale):
                        for l in [0, 1]:
                            direction_images[
                                n, :, top + k - l, 3 + abs(k - scale // 2)
                            ] = 0
                else:
                    # Cross for any other token value.
                    for k in range(2, scale - 2):
                        for l in [0, 1]:
                            direction_images[n, :, top + k - l, k] = 0
                            direction_images[n, :, top + k - l, scale - 1 - k] = 0

            pieces += [
                separator,
                direction_images,
                separator,
                self.frame2img(
                    seq[:, t : t + self.height * self.width].reshape(
                        -1, self.height, self.width
                    ),
                    scale,
                ),
            ]

            t += self.height * self.width

        return torch.cat(pieces, dim=3)

    def seq2str(self, seq):
        """Return one string per row of ``seq``, mapping each token through token2char."""
        return ["".join(self.token2char[v] for v in s) for s in seq]

    def save_image(self, input, result_dir, filename):
        """Render ``input`` with seq2img and save it to os.path.join(result_dir, filename)."""
        img = self.seq2img(input.to("cpu"))
        image_name = os.path.join(result_dir, filename)
        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)

    def save_quizzes(self, input, result_dir, filename_prefix):
        """Save ``input`` as ``<filename_prefix>.png`` in ``result_dir``."""
        self.save_image(input, result_dir, filename_prefix + ".png")
309         self.save_image(input, result_dir, filename_prefix + ".png")
310
311
312 ######################################################################
313
if __name__ == "__main__":
    # Ad-hoc driver: time frame-sequence generation, then save a sample of
    # token sequences as /tmp/seq.png.
    import time

    wireworld = Wireworld(height=8, width=10, nb_iterations=2, speed=5)

    start_time = time.perf_counter()
    frame_sequences = wireworld.generate_frame_sequences(nb=96)
    delay = time.perf_counter() - start_time
    # Generation throughput, printed with two decimal places.
    print(f"{frame_sequences.size(0)/delay:.2f} seq/s")

    # print(wireworld.seq2str(seq[:4]))

    # for t in range(frame_sequences.size(1)):
    # img = wireworld.seq2img(frame_sequences[:, t])
    # torchvision.utils.save_image(
    # img.float() / 255.0,
    # f"/tmp/frame_{t:03d}.png",
    # nrow=8,
    # padding=6,
    # pad_value=0,
    # )

    # m = (torch.rand(seq.size()) < 0.05).long()
    # seq = (1 - m) * seq + m * 23

    token_sequences = wireworld.generate_token_sequences(32)
    wireworld.save_quizzes(token_sequences, "/tmp", "seq")
    # img = wireworld.seq2img(frame_sequences[:60])

    # torchvision.utils.save_image(
    # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=10, pad_value=0.1
    # )