Update.
[culture.git] / wireworld.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, tqdm, os
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
17 import problem
18
19
class Wireworld(problem.Problem):
    """Wireworld cellular-automaton sequences for the ``problem.Problem`` interface.

    A frame is a ``height x width`` grid of cell tokens (``token_empty``,
    ``token_head``, ``token_tail``, ``token_conductor``). A token sequence is
    several flattened frames separated by a direction token
    (``token_forward`` / ``token_backward``) telling whether the frames are
    listed in the order they were simulated or in reverse.
    """

    # RGB colour of each cell token, indexed by token value (0..3).
    colors = torch.tensor(
        [
            [128, 128, 128],  # token_empty
            [128, 128, 255],  # token_head
            [255, 0, 0],  # token_tail
            [255, 255, 0],  # token_conductor
        ]
    )

    token_empty = 0
    token_head = 1
    token_tail = 2
    token_conductor = 3
    token_forward = 4
    token_backward = 5

    # One display character per token value: "_ABC><".
    token2char = (
        "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
    )

    def __init__(
        self, height=6, width=8, nb_objects=2, nb_walls=2, speed=1, nb_iterations=4
    ):
        """Store the generation parameters.

        Args:
            height, width: grid size of a frame.
            nb_objects, nb_walls: stored but not used by any method of this class.
            speed: number of automaton steps between two kept frames.
            nb_iterations: number of frames kept per sequence.
        """
        self.height = height
        self.width = width
        self.nb_objects = nb_objects
        self.nb_walls = nb_walls
        self.speed = speed
        self.nb_iterations = nb_iterations

    def direction_tokens(self):
        """Return the ``(forward, backward)`` direction token values."""
        return self.token_forward, self.token_backward

    @staticmethod
    def _random_unit_step():
        """Draw a random axis-aligned unit step ``(di, dj)`` (one of the 4 directions)."""
        axis, sign = torch.randint(2, (2,)).tolist()
        sign = 2 * sign - 1
        return (sign, 0) if axis == 1 else (0, sign)

    def _draw_wires(self, frame):
        """Draw random straight conductor segments on the 2D ``frame`` (in place).

        Each segment starts at a random cell and extends in a random direction
        until it leaves the grid or reaches a cell with more than one
        conductor neighbour. Segments are added until the frame holds more
        than ``width`` conductor cells, after which each further round stops
        with probability 1/2.
        """
        while True:
            i = torch.randint(self.height, (1,)).item()
            j = torch.randint(self.width, (1,)).item()
            di, dj = self._random_unit_step()
            while 0 <= i < self.height and 0 <= j < self.width:
                nb_conductor_neighbors = sum(
                    1
                    for a, b in ((i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1))
                    if 0 <= a < self.height
                    and 0 <= b < self.width
                    and frame[a, b] == self.token_conductor
                )
                # Stop before a segment runs into an already-wired area.
                if nb_conductor_neighbors > 1:
                    break
                frame[i, j] = self.token_conductor
                i += di
                j += dj
            nb_conductors = (frame == self.token_conductor).long().sum()
            if nb_conductors > self.width and torch.rand(1) < 0.5:
                break

    def _place_electrons(self, frame):
        """Turn random pairs of adjacent conductor cells into head/tail pairs (in place).

        Each round makes up to ``height * width`` attempts to find two adjacent
        conductor cells. Rounds repeat until a draw below 0.75 stops them.
        """
        while True:
            for _ in range(self.height * self.width):
                i = torch.randint(self.height, (1,)).item()
                j = torch.randint(self.width, (1,)).item()
                di, dj = self._random_unit_step()
                ti, tj = i + di, j + dj
                if (
                    0 <= ti < self.height
                    and 0 <= tj < self.width
                    and frame[i, j] == self.token_conductor
                    and frame[ti, tj] == self.token_conductor
                ):
                    frame[i, j] = self.token_head
                    frame[ti, tj] = self.token_tail
                    break

            if torch.rand(1) < 0.75:
                break

    def _simulate(self, result):
        """Fill frames ``1..`` of every sequence in ``result`` by evolving frame 0 (in place).

        Wireworld rule:
        - empty stays empty;
        - head becomes tail;
        - tail becomes conductor;
        - conductor becomes head when exactly one or two of its 8 neighbours
          are heads, and otherwise stays conductor.
        """
        # The 3x3 all-ones kernel counts heads around each cell. The centre is
        # included, but it is never a head for the conductor cells whose
        # update uses the count.
        weight = torch.full((1, 1, 3, 3), 1.0)
        for t in range(result.size(1) - 1):
            current = result[:, t]
            nb_head_neighbors = (
                F.conv2d(
                    input=(current == self.token_head).float()[:, None, :, :],
                    weight=weight,
                    padding=1,
                )
                .long()
                .squeeze(1)
            )
            conductor = current == self.token_conductor
            excited = conductor & (
                (nb_head_neighbors == 1) | (nb_head_neighbors == 2)
            )

            nxt = torch.full_like(current, self.token_empty)
            nxt[current == self.token_head] = self.token_tail
            nxt[(current == self.token_tail) | conductor] = self.token_conductor
            nxt[excited] = self.token_head
            result[:, t + 1] = nxt

    def _generate_candidates(self, nb):
        """Simulate ``nb`` random circuits and keep those whose last frame has a head.

        Returns a long tensor (<= nb, nb_iterations, height, width).
        """
        result = torch.full(
            (nb, self.nb_iterations * self.speed, self.height, self.width),
            self.token_empty,
        )

        for n in range(nb):
            frame = result[n, 0]  # view: drawing writes into result
            self._draw_wires(frame)
            self._place_electrons(frame)

        self._simulate(result)

        # Keep one frame every `speed` automaton steps.
        result = result[:, :: self.speed]

        alive = (result[:, -1] == self.token_head).flatten(1).any(dim=1)
        return result[alive]

    def generate_frame_sequences(self, nb):
        """Generate ``nb`` random Wireworld frame sequences.

        Each sequence starts from random conductor wiring carrying random
        electrons and is evolved for ``nb_iterations * speed`` automaton steps.
        Every ``speed``-th frame is kept. Only sequences whose last kept frame
        still contains an electron head are returned. Candidates are generated
        in batches of four times the number still missing until enough survive.

        Args:
            nb: number of sequences to return.

        Returns:
            A long tensor of shape ``(nb, nb_iterations, height, width)``.

        Raises:
            ValueError: if ``height < 2``. The grid then cannot hold more than
                ``width`` conductor cells, so wire drawing would never stop.
        """
        if nb <= 0:
            return torch.full(
                (0, self.nb_iterations, self.height, self.width), self.token_empty
            )

        if self.height < 2:
            raise ValueError(
                f"Wireworld needs height >= 2 to generate circuits, got {self.height}"
            )

        batches = []
        nb_missing = nb
        while nb_missing > 0:
            batch = self._generate_candidates(4 * nb_missing)
            batches.append(batch)
            nb_missing -= batch.size(0)

        return torch.cat(batches, dim=0)[:nb]

    def generate_token_sequences(self, nb):
        """Generate ``nb`` flattened token sequences.

        Each output row holds one frame sequence's flattened frames separated
        by a single direction token. With probability 1/2 the frames are in
        simulation order, separated by ``token_forward``. Otherwise they are
        reversed and separated by ``token_backward``.

        Returns:
            A long tensor of shape
            ``(nb, nb_iterations * (height * width + 1) - 1)``.
        """
        frame_sequences = self.generate_frame_sequences(nb)
        seq_len = self.nb_iterations * (self.height * self.width + 1) - 1

        result = []
        for frame_sequence in frame_sequences:
            if torch.rand(1) < 0.5:
                frames, direction = frame_sequence, self.token_forward
            else:
                frames, direction = frame_sequence.flip(0), self.token_backward

            separator = torch.tensor([direction])
            parts = []
            for frame in frames:
                if parts:
                    parts.append(separator)
                parts.append(frame.flatten())

            result.append(torch.cat(parts, dim=0)[None, :])

        if not result:
            return torch.empty((0, seq_len), dtype=torch.long)

        return torch.cat(result, dim=0)

    ######################################################################

    def frame2img(self, x, scale=15):
        """Render frames of cell tokens as RGB images.

        Args:
            x: tensor of tokens reshapeable to ``(N, height, width)``.
            scale: pixel size of one cell, including its grid line.

        Returns:
            A long tensor ``(N, 3, height * scale - 1, width * scale - 1)`` of
            values in 0..255. Black grid lines appear only between cells.
            Tokens outside 0..3 are drawn in the empty-cell colour with a
            black cross.
        """
        x = x.reshape(-1, self.height, self.width)
        valid = torch.logical_and(x >= 0, x < 4).long()

        # Invalid tokens are mapped to index 0 so the colour lookup stays in range.
        x = self.colors[x * valid].permute(0, 3, 1, 2)
        s = x.shape
        x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
        x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)

        # Blacken the first row/column of every cell, then crop the leading
        # line so grid lines only separate cells.
        x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
        x[:, :, torch.arange(0, x.size(2), scale), :] = 0
        x = x[:, :, 1:, 1:]

        # Two-pixel-wide cross over each invalid cell; only those cells are visited.
        for n, i, j in (valid == 0).nonzero().tolist():
            for k in range(2, scale - 2):
                for l in (0, 1):
                    x[n, :, i * scale + k, j * scale + k - l] = 0
                    x[n, :, i * scale + scale - 1 - k, j * scale + k - l] = 0

        return x

    def seq2img(self, seq, scale=15):
        """Render token sequences as images, frames laid out left to right.

        ``seq`` is read as: one frame (``height * width`` tokens), then
        repeatedly one direction token followed by one frame. This is the
        layout built by ``generate_token_sequences``. Each direction token is
        drawn as a glyph:
        - a right-pointing chevron for ``token_forward``;
        - a left-pointing one for ``token_backward``;
        - a cross for any other value.

        Args:
            seq: ``(N, L)`` token tensor.
            scale: pixel size of one cell.

        Returns:
            A long tensor ``(N, 3, height * scale - 1, total_width)``.
        """
        frame_len = self.height * self.width
        img_height = self.height * scale - 1
        half = scale // 2
        # Top row of the glyph, so that it is vertically centred.
        glyph_top = (self.height * scale) // 2 - half

        parts = [
            self.frame2img(
                seq[:, :frame_len].reshape(-1, self.height, self.width),
                scale,
            )
        ]

        separator = torch.full((seq.size(0), 3, img_height, 1), 0)

        t = frame_len

        while t < seq.size(1):
            direction_tokens = seq[:, t]
            t += 1

            # Glyph canvas filled with the empty-cell colour.
            direction_images = self.colors[
                torch.full((direction_tokens.size(0), img_height, scale), 0)
            ].permute(0, 3, 1, 2)

            for n, token in enumerate(direction_tokens.tolist()):
                if token == self.token_forward:
                    for k in range(scale):
                        for l in (0, 1):
                            direction_images[
                                n, :, glyph_top + k - l, 3 + half - abs(k - half)
                            ] = 0
                elif token == self.token_backward:
                    for k in range(scale):
                        for l in (0, 1):
                            direction_images[
                                n, :, glyph_top + k - l, 3 + abs(k - half)
                            ] = 0
                else:
                    for k in range(2, scale - 2):
                        for l in (0, 1):
                            direction_images[n, :, glyph_top + k - l, k] = 0
                            direction_images[
                                n, :, glyph_top + k - l, scale - 1 - k
                            ] = 0

            parts += [
                separator,
                direction_images,
                separator,
                self.frame2img(
                    seq[:, t : t + frame_len].reshape(-1, self.height, self.width),
                    scale,
                ),
            ]

            t += frame_len

        return torch.cat(parts, dim=3)

    def seq2str(self, seq):
        """Convert each token sequence to a string using ``token2char``.

        The mapping is ``_`` empty, ``A``/``B``/``C`` head/tail/conductor,
        ``>``/``<`` forward/backward.
        """
        result = []
        for s in seq:
            result.append("".join([self.token2char[v] for v in s]))
        return result

    def save_image(self, input, result_dir, filename):
        """Render ``input`` with ``seq2img`` and save it to ``result_dir/filename`` (6 per row)."""
        img = self.seq2img(input.to("cpu"))
        image_name = os.path.join(result_dir, filename)
        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)

    def save_quizzes(self, input, result_dir, filename_prefix):
        """Save ``input`` as the image ``<filename_prefix>.png`` in ``result_dir``."""
        self.save_image(input, result_dir, filename_prefix + ".png")
301
302
303 ######################################################################
304
if __name__ == "__main__":
    import time

    wireworld = Wireworld(height=10, width=15, nb_iterations=2, speed=5)

    # Benchmark raw frame-sequence generation.
    start_time = time.perf_counter()
    frame_sequences = wireworld.generate_frame_sequences(nb=96)
    delay = time.perf_counter() - start_time
    # ".2f": two decimals (the previous ":02f" was a width/zero-pad typo).
    print(f"{frame_sequences.size(0) / delay:.2f} seq/s")

    # Render a batch of token sequences to /tmp/seq.png.
    token_sequences = wireworld.generate_token_sequences(32)
    wireworld.save_quizzes(token_sequences, "/tmp", "seq")