aff236d799ed516a1d6e9c1027a78b8968b58f5b
[culture.git] / wireworld.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, tqdm, os
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
17 import problem
18
19
class Wireworld(problem.Problem):
    """Quiz generator based on the Wireworld cellular automaton.

    A frame is a ``height x width`` grid of cell tokens (``token_empty``,
    ``token_head``, ``token_tail``, ``token_conductor``). A token sequence is
    ``nb_iterations`` flattened frames separated by a direction token:
    ``token_forward`` when the frames are listed in simulation order,
    ``token_backward`` when they are listed in reverse.
    """

    # RGB color of each cell state, indexed by its token value:
    # empty=gray, head=light blue, tail=red, conductor=yellow.
    colors = torch.tensor(
        [
            [128, 128, 128],
            [128, 128, 255],
            [255, 0, 0],
            [255, 255, 0],
        ]
    )

    token_empty = 0
    token_head = 1
    token_tail = 2
    token_conductor = 3
    token_forward = 4
    token_backward = 5

    # One character per token, "_ABC><": "_" for empty, one letter per other
    # cell state, then ">" / "<" for the forward / backward direction tokens.
    token2char = (
        "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
    )

    # Upper bound on resampling rounds in generate_frame_sequences, so that a
    # configuration in which heads (almost) never survive to the last frame
    # fails loudly instead of looping forever.
    max_sampling_rounds = 1000

    def __init__(
        self, height=6, width=8, nb_objects=2, nb_walls=2, speed=1, nb_iterations=4
    ):
        """Store the generation parameters.

        Args:
            height, width: size of a frame, in cells.
            nb_objects, nb_walls: stored as attributes; not used by the
                methods of this class.
            speed: number of simulation steps between two consecutive frames
                of a generated sequence.
            nb_iterations: number of frames per generated sequence.
        """
        self.height = height
        self.width = width
        self.nb_objects = nb_objects
        self.nb_walls = nb_walls
        self.speed = speed
        self.nb_iterations = nb_iterations

    def direction_tokens(self):
        """Return the ``(forward, backward)`` direction tokens."""
        return self.token_forward, self.token_backward

    def _random_initial_frames(self, nb):
        """Draw ``nb`` random starting grids: straight conductor wires plus noise.

        Each grid gets one or more wires; after each wire, another one is drawn
        with probability 1/2. A wire starts at a random cell and runs
        horizontally or vertically in one direction up to the border. Then
        about 1% of all cells are overwritten with a uniformly random cell
        state, which is the only way heads enter the initial grid.
        """
        frames = torch.full((nb, self.height, self.width), self.token_empty)

        for n in range(nb):
            while True:
                i = int(torch.randint(self.height, (1,)))
                j = int(torch.randint(self.width, (1,)))
                v = torch.randint(2, (2,))
                # v[0] picks the axis (1: vertical, 0: horizontal), v[1] the sign.
                sign = 2 * int(v[1]) - 1
                di, dj = (sign, 0) if int(v[0]) == 1 else (0, sign)
                while 0 <= i < self.height and 0 <= j < self.width:
                    frames[n, i, j] = self.token_conductor
                    i += di
                    j += dj
                if torch.rand(1) < 0.5:
                    break

        noise = torch.rand(frames.size()) < 0.01
        random_states = torch.randint(len(self.colors), frames.size())
        return torch.where(noise, random_states, frames)

    def _step(self, frames):
        """Apply one Wireworld update to a batch of grids of shape (N, height, width).

        empty -> empty, head -> tail, tail -> conductor, and
        conductor -> head if exactly 1 or 2 neighbouring cells are heads,
        otherwise it remains a conductor. Any other token value maps to empty.
        """
        heads = (frames == self.token_head).float()[:, None, :, :]
        # The 3x3 all-ones kernel also counts the centre cell; that is harmless
        # because the count is only used for conductor cells, which are not heads.
        nb_head_neighbors = (
            F.conv2d(input=heads, weight=torch.full((1, 1, 3, 3), 1.0), padding=1)
            .long()
            .squeeze(1)
        )
        excited = (nb_head_neighbors == 1) | (nb_head_neighbors == 2)
        conductor = frames == self.token_conductor

        result = torch.full_like(frames, self.token_empty)
        result[frames == self.token_head] = self.token_tail
        result[frames == self.token_tail] = self.token_conductor
        result[conductor] = self.token_conductor
        result[conductor & excited] = self.token_head
        return result

    def _sample_frame_sequences(self, nb):
        """Simulate ``nb`` random runs and keep those still active at the end.

        Returns a (M, nb_iterations, height, width) tensor with M <= nb made of
        every ``speed``-th simulated frame, restricted to runs whose last kept
        frame still contains at least one head.
        """
        # Frames 0, speed, ..., (nb_iterations - 1) * speed are kept, so the
        # simulation must produce exactly this many frames.
        nb_frames = (self.nb_iterations - 1) * self.speed + 1

        frames = torch.full(
            (nb, nb_frames, self.height, self.width), self.token_empty
        )
        frames[:, 0] = self._random_initial_frames(nb)

        for t in range(nb_frames - 1):
            frames[:, t + 1] = self._step(frames[:, t])

        # Subsample along the time axis (dim 1), not the batch axis.
        kept = frames[
            :, torch.arange(self.nb_iterations, device=frames.device) * self.speed
        ]
        alive = (kept[:, -1] == self.token_head).flatten(1).any(dim=1)
        return kept[alive]

    def generate_frame_sequences(self, nb):
        """Return ``nb`` random Wireworld runs, shape (nb, nb_iterations, height, width).

        Runs whose last frame has no head left are discarded, so candidates are
        over-sampled (4x the number still missing) and drawn again until ``nb``
        runs have been collected.

        Raises:
            RuntimeError: if fewer than ``nb`` runs were collected after
                ``max_sampling_rounds`` rounds.
        """
        # Seeded with an empty chunk so that nb == 0 returns a well-shaped tensor.
        chunks = [
            torch.full(
                (0, self.nb_iterations, self.height, self.width), self.token_empty
            )
        ]
        missing = nb
        for _ in range(self.max_sampling_rounds):
            if missing <= 0:
                break
            chunk = self._sample_frame_sequences(4 * missing)
            chunks.append(chunk)
            missing -= chunk.size(0)

        if missing > 0:
            raise RuntimeError(
                f"only {nb - missing} of {nb} frame sequences still had a head "
                f"in their last frame after {self.max_sampling_rounds} rounds"
            )

        return torch.cat(chunks, dim=0)[:nb]

    def generate_token_sequences(self, nb):
        """Return ``nb`` token sequences, one per row.

        Each row is the flattened frames of a run from generate_frame_sequences.
        With probability 1/2 the frames are kept in order and separated by
        ``token_forward``; otherwise they are reversed and separated by
        ``token_backward``.
        """
        frame_sequences = self.generate_frame_sequences(nb)

        result = []

        for frame_sequence in frame_sequences:
            if torch.rand(1) < 0.5:
                frames, direction = frame_sequence, self.token_forward
            else:
                frames, direction = reversed(frame_sequence), self.token_backward

            a = []
            for frame in frames:
                if a:
                    a.append(torch.tensor([direction]))
                a.append(frame.flatten())

            result.append(torch.cat(a, dim=0)[None, :])

        if not result:
            seq_len = (
                self.nb_iterations * self.height * self.width + self.nb_iterations - 1
            )
            return torch.full((0, seq_len), self.token_empty)

        return torch.cat(result, dim=0)

    ######################################################################

    def frame2img(self, x, scale=15):
        """Render frames as RGB images on a black grid.

        ``x`` is reshaped to (-1, height, width). Each cell becomes a
        ``scale x scale`` square in its state color. Cells whose token is not a
        cell state (outside ``0 .. len(colors) - 1``) are drawn in the empty
        color with a black cross. Returns an integer tensor of shape
        (N, 3, height * scale - 1, width * scale - 1) with values in 0..255.
        """
        x = x.reshape(-1, self.height, self.width)
        m = torch.logical_and(x >= 0, x < len(self.colors)).long()

        x = self.colors[x * m].permute(0, 3, 1, 2)
        s = x.shape
        x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
        x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)

        # Black grid lines, then drop the leading line on each axis.
        x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
        x[:, :, torch.arange(0, x.size(2), scale), :] = 0
        x = x[:, :, 1:, 1:]

        # Only visit invalid cells instead of scanning every cell.
        for n, i, j in (m == 0).nonzero().tolist():
            for k in range(2, scale - 2):
                for l in [0, 1]:
                    x[n, :, i * scale + k, j * scale + k - l] = 0
                    x[n, :, i * scale + scale - 1 - k, j * scale + k - l] = 0

        return x

    def seq2img(self, seq, scale=15):
        """Render token sequences as images, one row of panels per sequence.

        ``seq`` is expected to be (N, L), laid out as frame, direction token,
        frame, ... as produced by generate_token_sequences. Frames are drawn
        with frame2img. Each direction token becomes a panel showing ">" for
        ``token_forward``, "<" for ``token_backward`` and a cross otherwise,
        framed by 1-pixel black separators. Returns an integer RGB tensor of
        shape (N, 3, height * scale - 1, total width).
        """
        frame_len = self.height * self.width
        panel_height = self.height * scale - 1
        # First row of the arrow/cross glyph, chosen to center it vertically.
        top = (self.height * scale) // 2 - scale // 2

        def render_frames(start):
            return self.frame2img(
                seq[:, start : start + frame_len].reshape(
                    -1, self.height, self.width
                ),
                scale,
            )

        images = [render_frames(0)]

        separator = torch.full((seq.size(0), 3, panel_height, 1), 0)

        t = frame_len

        while t < seq.size(1):
            direction_tokens = seq[:, t]
            t += 1

            direction_images = self.colors[
                torch.full((direction_tokens.size(0), panel_height, scale), 0)
            ].permute(0, 3, 1, 2)

            for n in range(direction_tokens.size(0)):
                if direction_tokens[n] == self.token_forward:
                    for k in range(scale):
                        for l in [0, 1]:
                            direction_images[
                                n, :, top + k - l, 3 + scale // 2 - abs(k - scale // 2)
                            ] = 0
                elif direction_tokens[n] == self.token_backward:
                    for k in range(scale):
                        for l in [0, 1]:
                            direction_images[
                                n, :, top + k - l, 3 + abs(k - scale // 2)
                            ] = 0
                else:
                    for k in range(2, scale - 2):
                        for l in [0, 1]:
                            direction_images[n, :, top + k - l, k] = 0
                            direction_images[n, :, top + k - l, scale - 1 - k] = 0

            images += [separator, direction_images, separator, render_frames(t)]

            t += frame_len

        return torch.cat(images, dim=3)

    def seq2str(self, seq):
        """Return one string per row of ``seq``, mapping tokens via token2char."""
        result = []
        for s in seq:
            result.append("".join([self.token2char[v] for v in s]))
        return result

    def save_image(self, input, result_dir, filename):
        """Render ``input`` with seq2img and save it to ``result_dir/filename``."""
        img = self.seq2img(input.to("cpu"))
        image_name = os.path.join(result_dir, filename)
        # Pixel values are scaled from 0..255 to [0, 1] for torchvision.
        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)

    def save_quizzes(self, input, result_dir, filename_prefix):
        """Save ``input`` as the image ``result_dir/<filename_prefix>.png``."""
        self.save_image(input, result_dir, filename_prefix + ".png")
261         self.save_image(input, result_dir, filename_prefix + ".png")
262
263
264 ######################################################################
265
if __name__ == "__main__":
    import time

    wireworld = Wireworld(height=10, width=15, nb_iterations=2, speed=1)

    # Measure the throughput of the raw frame-sequence generator.
    start_time = time.perf_counter()
    frame_sequences = wireworld.generate_frame_sequences(nb=96)
    delay = time.perf_counter() - start_time
    # ".02f" prints two decimals; the former "02f" meant width 2, six decimals.
    print(f"{frame_sequences.size(0)/delay:.02f} seq/s")

    # Render some token sequences to /tmp/seq.png for visual inspection.
    token_sequences = wireworld.generate_token_sequences(32)
    wireworld.save_quizzes(token_sequences, "/tmp", "seq")
296     # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=10, pad_value=0.1
297     # )