Update.
[culture.git] / wireworld.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, tqdm, os
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
17 import problem
18
19
class Wireworld(problem.Problem):
    """Wireworld cellular automaton rendered as frames, token sequences and images.

    Cells hold one of four state tokens (empty, head, tail, conductor).
    Two extra tokens mark the temporal direction between consecutive
    frames of a token sequence.
    """

    # RGB colour of each cell state, indexed by the state token.
    colors = torch.tensor(
        [
            [128, 128, 128],  # token_empty
            [128, 128, 255],  # token_head
            [255, 0, 0],  # token_tail
            [255, 255, 0],  # token_conductor
        ]
    )

    # Cell states (0..3) followed by the two direction markers (4, 5).
    (
        token_empty,
        token_head,
        token_tail,
        token_conductor,
        token_forward,
        token_backward,
    ) = range(6)

    # One printable character per token: "_" for empty, a letter for each
    # other cell state, then ">" and "<" for forward and backward.
    token2char = (
        "_" + "".join(map(chr, range(ord("A"), ord("A") + len(colors) - 1))) + "><"
    )
40
41     def __init__(
42         self, height=6, width=8, nb_objects=2, nb_walls=2, speed=1, nb_iterations=4
43     ):
44         self.height = height
45         self.width = width
46         self.nb_objects = nb_objects
47         self.nb_walls = nb_walls
48         self.speed = speed
49         self.nb_iterations = nb_iterations
50
51     def direction_tokens(self):
52         return self.token_forward, self.token_backward
53
54     def generate_frame_sequences(self, nb):
55         result = []
56         N = 100
57         for _ in tqdm.tqdm(
58             range(0, nb + N, N), dynamic_ncols=True, desc="world generation"
59         ):
60             result.append(self.generate_frame_sequences_hard(100))
61         return torch.cat(result, dim=0)[:nb]
62
63     def generate_frame_sequences_hard(self, nb):
64         frame_sequences = []
65         nb_frames = (self.nb_iterations - 1) * self.speed + 1
66
67         result = torch.full(
68             (nb * 4, nb_frames, self.height, self.width),
69             self.token_empty,
70         )
71
72         for n in range(result.size(0)):
73             while True:
74                 i = torch.randint(self.height, (1,))
75                 j = torch.randint(self.width, (1,))
76                 v = torch.randint(2, (2,))
77                 vi = v[0] * (v[1] * 2 - 1)
78                 vj = (1 - v[0]) * (v[1] * 2 - 1)
79                 while True:
80                     if i < 0 or i >= self.height or j < 0 or j >= self.width:
81                         break
82                     o = 0
83                     if i > 0:
84                         o += (result[n, 0, i - 1, j] == self.token_conductor).long()
85                     if i < self.height - 1:
86                         o += (result[n, 0, i + 1, j] == self.token_conductor).long()
87                     if j > 0:
88                         o += (result[n, 0, i, j - 1] == self.token_conductor).long()
89                     if j < self.width - 1:
90                         o += (result[n, 0, i, j + 1] == self.token_conductor).long()
91                     if o > 1:
92                         break
93                     result[n, 0, i, j] = self.token_conductor
94                     i += vi
95                     j += vj
96                 if (
97                     result[n, 0] == self.token_conductor
98                 ).long().sum() > self.width and torch.rand(1) < 0.5:
99                     break
100
101             while True:
102                 for _ in range(self.height * self.width):
103                     i = torch.randint(self.height, (1,))
104                     j = torch.randint(self.width, (1,))
105                     v = torch.randint(2, (2,))
106                     vi = v[0] * (v[1] * 2 - 1)
107                     vj = (1 - v[0]) * (v[1] * 2 - 1)
108                     if (
109                         i + vi >= 0
110                         and i + vi < self.height
111                         and j + vj >= 0
112                         and j + vj < self.width
113                         and result[n, 0, i, j] == self.token_conductor
114                         and result[n, 0, i + vi, j + vj] == self.token_conductor
115                     ):
116                         result[n, 0, i, j] = self.token_head
117                         result[n, 0, i + vi, j + vj] = self.token_tail
118                         break
119
120                 # if torch.rand(1) < 0.75:
121                 break
122
123         weight = torch.full((1, 1, 3, 3), 1.0)
124
125         mask = (torch.rand(result[:, 0].size()) < 0.01).long()
126         rand = torch.randint(4, mask.size())
127         result[:, 0] = mask * rand + (1 - mask) * result[:, 0]
128
129         # empty->empty
130         # head->tail
131         # tail->conductor
132         # conductor->head if 1 or 2 head in the neighborhood, or remains conductor
133
134         nb_heads = (result[:, 0] == self.token_head).flatten(1).long().sum(dim=1)
135         valid = nb_heads > 0
136
137         for l in range(nb_frames - 1):
138             nb_head_neighbors = (
139                 F.conv2d(
140                     input=(result[:, l] == self.token_head).float()[:, None, :, :],
141                     weight=weight,
142                     padding=1,
143                 )
144                 .long()
145                 .squeeze(1)
146             )
147             mask_1_or_2_heads = (nb_head_neighbors == 1).long() + (
148                 nb_head_neighbors == 2
149             ).long()
150             result[:, l + 1] = (
151                 (result[:, l] == self.token_empty).long() * self.token_empty
152                 + (result[:, l] == self.token_head).long() * self.token_tail
153                 + (result[:, l] == self.token_tail).long() * self.token_conductor
154                 + (result[:, l] == self.token_conductor).long()
155                 * (
156                     mask_1_or_2_heads * self.token_head
157                     + (1 - mask_1_or_2_heads) * self.token_conductor
158                 )
159             )
160             pred_nb_heads = nb_heads
161             nb_heads = (
162                 (result[:, l + 1] == self.token_head).flatten(1).long().sum(dim=1)
163             )
164             valid = torch.logical_and(valid, (nb_heads >= pred_nb_heads))
165
166         result = result[valid]
167
168         result = result[
169             :, torch.arange(self.nb_iterations, device=result.device) * self.speed
170         ]
171
172         i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0
173         result = result[i]
174
175         # print(f"{result.size(0)=} {nb=}")
176
177         if result.size(0) < nb:
178             # print(result.size(0))
179             result = torch.cat(
180                 [result, self.generate_frame_sequences(nb - result.size(0))], dim=0
181             )
182
183         return result[:nb]
184
185     def generate_token_sequences(self, nb):
186         frame_sequences = self.generate_frame_sequences(nb)
187
188         result = []
189
190         for frame_sequence in frame_sequences:
191             a = []
192             if torch.rand(1) < 0.5:
193                 for frame in frame_sequence:
194                     if len(a) > 0:
195                         a.append(torch.tensor([self.token_forward]))
196                     a.append(frame.flatten())
197             else:
198                 for frame in reversed(frame_sequence):
199                     if len(a) > 0:
200                         a.append(torch.tensor([self.token_backward]))
201                     a.append(frame.flatten())
202
203             result.append(torch.cat(a, dim=0)[None, :])
204
205         return torch.cat(result, dim=0)
206
207     ######################################################################
208
209     def frame2img(self, x, scale=15):
210         x = x.reshape(-1, self.height, self.width)
211         m = torch.logical_and(x >= 0, x < 4).long()
212
213         x = self.colors[x * m].permute(0, 3, 1, 2)
214         s = x.shape
215         x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
216         x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
217
218         x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
219         x[:, :, torch.arange(0, x.size(2), scale), :] = 0
220         x = x[:, :, 1:, 1:]
221
222         for n in range(m.size(0)):
223             for i in range(m.size(1)):
224                 for j in range(m.size(2)):
225                     if m[n, i, j] == 0:
226                         for k in range(2, scale - 2):
227                             for l in [0, 1]:
228                                 x[n, :, i * scale + k, j * scale + k - l] = 0
229                                 x[
230                                     n, :, i * scale + scale - 1 - k, j * scale + k - l
231                                 ] = 0
232
233         return x
234
235     def seq2img(self, seq, scale=15):
236         all = [
237             self.frame2img(
238                 seq[:, : self.height * self.width].reshape(-1, self.height, self.width),
239                 scale,
240             )
241         ]
242
243         separator = torch.full((seq.size(0), 3, self.height * scale - 1, 1), 0)
244
245         t = self.height * self.width
246
247         while t < seq.size(1):
248             direction_tokens = seq[:, t]
249             t += 1
250
251             direction_images = self.colors[
252                 torch.full(
253                     (direction_tokens.size(0), self.height * scale - 1, scale), 0
254                 )
255             ].permute(0, 3, 1, 2)
256
257             for n in range(direction_tokens.size(0)):
258                 if direction_tokens[n] == self.token_forward:
259                     for k in range(scale):
260                         for l in [0, 1]:
261                             direction_images[
262                                 n,
263                                 :,
264                                 (self.height * scale) // 2 - scale // 2 + k - l,
265                                 3 + scale // 2 - abs(k - scale // 2),
266                             ] = 0
267                 elif direction_tokens[n] == self.token_backward:
268                     for k in range(scale):
269                         for l in [0, 1]:
270                             direction_images[
271                                 n,
272                                 :,
273                                 (self.height * scale) // 2 - scale // 2 + k - l,
274                                 3 + abs(k - scale // 2),
275                             ] = 0
276                 else:
277                     for k in range(2, scale - 2):
278                         for l in [0, 1]:
279                             direction_images[
280                                 n,
281                                 :,
282                                 (self.height * scale) // 2 - scale // 2 + k - l,
283                                 k,
284                             ] = 0
285                             direction_images[
286                                 n,
287                                 :,
288                                 (self.height * scale) // 2 - scale // 2 + k - l,
289                                 scale - 1 - k,
290                             ] = 0
291
292             all += [
293                 separator,
294                 direction_images,
295                 separator,
296                 self.frame2img(
297                     seq[:, t : t + self.height * self.width].reshape(
298                         -1, self.height, self.width
299                     ),
300                     scale,
301                 ),
302             ]
303
304             t += self.height * self.width
305
306         return torch.cat(all, dim=3)
307
308     def seq2str(self, seq):
309         result = []
310         for s in seq:
311             result.append("".join([self.token2char[v] for v in s]))
312         return result
313
314     def save_image(self, input, result_dir, filename):
315         img = self.seq2img(input.to("cpu"))
316         image_name = os.path.join(result_dir, filename)
317         torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
318
319     def save_quizzes(self, input, result_dir, filename_prefix):
320         self.save_image(input, result_dir, filename_prefix + ".png")
321
322
323 ######################################################################
324
if __name__ == "__main__":
    # Demo: time the generation of frame sequences, dump every time step as
    # an image grid, then save a batch of token sequences as a picture.
    import time

    wireworld = Wireworld(height=8, width=10, nb_iterations=5, speed=1)

    start_time = time.perf_counter()
    frame_sequences = wireworld.generate_frame_sequences(nb=96)
    delay = time.perf_counter() - start_time
    # ".02f" prints two decimals; "02f" was only a minimum width.
    print(f"{frame_sequences.size(0)/delay:.02f} seq/s")

    for t in range(frame_sequences.size(1)):
        # frame_sequences[:, t] is a batch of single frames, so it is
        # rendered directly with frame2img.
        img = wireworld.frame2img(frame_sequences[:, t])
        torchvision.utils.save_image(
            img.float() / 255.0,
            f"/tmp/frame_{t:03d}.png",
            nrow=8,
            padding=6,
            pad_value=0,
        )

    wireworld = Wireworld(height=8, width=10, nb_iterations=2, speed=5)
    token_sequences = wireworld.generate_token_sequences(32)
    wireworld.save_quizzes(token_sequences, "/tmp", "seq")