Update.
[culture.git] / wireworld.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, tqdm, os
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
17 import problem
18
19
class Wireworld(problem.Problem):
    """Problem built on the Wireworld cellular automaton.

    A world is a ``height x width`` grid of cell tokens (empty, electron
    head, electron tail, conductor). A *frame sequence* is such a grid
    evolved for ``nb_iterations`` frames. A *token sequence* is the
    flattened frames of one simulation separated by a direction token:
    ``token_forward`` when the frames are in simulation order,
    ``token_backward`` when they are reversed.
    """

    # RGB colour of each cell token, indexed by token value:
    # empty (grey), head (light blue), tail (red), conductor (yellow).
    colors = torch.tensor(
        [
            [128, 128, 128],
            [128, 128, 255],
            [255, 0, 0],
            [255, 255, 0],
        ]
    )

    token_empty = 0
    token_head = 1
    token_tail = 2
    token_conductor = 3
    # Separators placed between consecutive frames of a token sequence.
    token_forward = 4
    token_backward = 5

    # One printable character per token, used by seq2str: "_" for empty,
    # "A".."C" for head/tail/conductor, then ">" and "<" for the directions.
    token2char = (
        "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
    )

    def __init__(self, height=6, width=8, nb_objects=2, nb_walls=2, nb_iterations=4):
        """Configure the world.

        Args:
            height: number of grid rows.
            width: number of grid columns.
            nb_objects: stored but not read by any method of this class.
            nb_walls: stored but not read by any method of this class.
            nb_iterations: number of frames in each simulated sequence.
        """
        self.height = height
        self.width = width
        self.nb_objects = nb_objects
        self.nb_walls = nb_walls
        self.nb_iterations = nb_iterations

    def direction_tokens(self):
        """Return the ``(forward, backward)`` separator tokens."""
        return self.token_forward, self.token_backward

    def generate_frame_sequences(self, nb):
        """Generate ``nb`` Wireworld simulations still active at their end.

        Each simulation starts from one or more straight conductor segments,
        each running from a random cell to the grid border. About 1% of the
        initial cells are then overwritten with a random cell token, and the
        world is evolved for ``nb_iterations - 1`` steps. Only simulations
        whose last frame still contains an electron head are kept.

        Args:
            nb: number of sequences to return.

        Returns:
            A ``(nb, nb_iterations, height, width)`` tensor of cell tokens.
        """
        # Over-sample: many random worlds have no head left in the last frame.
        result = torch.full(
            (nb * 4, self.nb_iterations, self.height, self.width), self.token_empty
        )

        for n in range(result.size(0)):
            # Draw conductor segments until a coin flip stops (at least one).
            while True:
                i = torch.randint(self.height, (1,))
                j = torch.randint(self.width, (1,))
                v = torch.randint(2, (2,))
                # v[0] picks the axis (1: vertical, 0: horizontal) and v[1]
                # the sign of the step, so exactly one of vi / vj is +-1.
                vi = v[0] * (v[1] * 2 - 1)
                vj = (1 - v[0]) * (v[1] * 2 - 1)
                while True:
                    if i < 0 or i >= self.height or j < 0 or j >= self.width:
                        break
                    result[n, 0, i, j] = self.token_conductor
                    i += vi
                    j += vj
                if torch.rand(1) < 0.5:
                    break

        # Overwrite about 1% of the initial cells with a random cell token.
        noise = (torch.rand(result[:, 0].size()) < 0.01).long()
        rand = torch.randint(4, noise.size())
        result[:, 0] = noise * rand + (1 - noise) * result[:, 0]

        # 3x3 all-ones kernel counting heads in the Moore neighbourhood. The
        # centre is included, but the count is only used for conductor cells,
        # which are never heads themselves.
        weight = torch.full((1, 1, 3, 3), 1.0)

        # Wireworld rules:
        #   empty     -> empty
        #   head      -> tail
        #   tail      -> conductor
        #   conductor -> head if 1 or 2 neighbouring heads, else conductor
        for l in range(self.nb_iterations - 1):
            current = result[:, l]
            nb_head_neighbors = (
                F.conv2d(
                    input=(current == self.token_head).float()[:, None, :, :],
                    weight=weight,
                    padding=1,
                )
                .long()
                .squeeze(1)
            )
            mask_1_or_2_heads = (nb_head_neighbors == 1).long() + (
                nb_head_neighbors == 2
            ).long()
            result[:, l + 1] = (
                (current == self.token_empty).long() * self.token_empty
                + (current == self.token_head).long() * self.token_tail
                + (current == self.token_tail).long() * self.token_conductor
                + (current == self.token_conductor).long()
                * (
                    mask_1_or_2_heads * self.token_head
                    + (1 - mask_1_or_2_heads) * self.token_conductor
                )
            )

        # Keep only simulations with at least one head in their last frame.
        alive = (result[:, -1] == self.token_head).flatten(1).any(dim=1)
        result = result[alive]

        if result.size(0) < nb:
            result = torch.cat(
                [result, self.generate_frame_sequences(nb - result.size(0))], dim=0
            )

        # Over-sampling can leave more than nb survivors; return exactly nb.
        return result[:nb]

    def generate_token_sequences(self, nb):
        """Generate ``nb`` flattened token sequences.

        Each simulation from ``generate_frame_sequences`` is written out with
        probability 1/2 in simulation order, with ``token_forward`` between
        consecutive frames. Otherwise it is written in reverse order, with
        ``token_backward`` between consecutive frames.

        Returns:
            A tensor of shape
            ``(nb, nb_iterations * height * width + nb_iterations - 1)``.
        """
        frame_sequences = self.generate_frame_sequences(nb)

        result = []

        for frame_sequence in frame_sequences:
            if torch.rand(1) < 0.5:
                frames, direction = frame_sequence, self.token_forward
            else:
                frames, direction = frame_sequence.flip(0), self.token_backward

            separator = torch.tensor([direction])
            pieces = []
            for frame in frames:
                if pieces:
                    pieces.append(separator)
                pieces.append(frame.flatten())

            result.append(torch.cat(pieces, dim=0)[None, :])

        return torch.cat(result, dim=0)

    ######################################################################

    def frame2img(self, x, scale=15):
        """Render frames of cell tokens as RGB images.

        Args:
            x: tensor of tokens, reshaped to ``(-1, height, width)``.
            scale: side length in pixels of one cell, grid line included.

        Returns:
            An integer tensor of shape
            ``(N, 3, height * scale - 1, width * scale - 1)`` with values in
            0..255. Cells are separated by black grid lines. Cells holding a
            token outside 0..3 are drawn in the empty colour with a black X.
        """
        tokens = x.reshape(-1, self.height, self.width)
        valid = torch.logical_and(tokens >= 0, tokens < 4).long()

        # Invalid tokens are mapped to colour 0 here and marked with an X below.
        img = self.colors[tokens * valid].permute(0, 3, 1, 2)
        s = img.shape
        # Upscale every cell to a scale x scale square.
        img = img[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
        img = img.reshape(s[0], s[1], s[2] * scale, s[3] * scale)

        # Black line on the first row/column of every cell, then drop the
        # leading row/column so the image starts with cell content.
        img[:, :, :, torch.arange(0, img.size(3), scale)] = 0
        img[:, :, torch.arange(0, img.size(2), scale), :] = 0
        img = img[:, :, 1:, 1:]

        # Draw a two-pixel-wide X over every invalid cell only.
        for n, i, j in (valid == 0).nonzero().tolist():
            for k in range(2, scale - 2):
                for l in (0, 1):
                    img[n, :, i * scale + k, j * scale + k - l] = 0
                    img[n, :, i * scale + scale - 1 - k, j * scale + k - l] = 0

        return img

    def seq2img(self, seq, scale=15):
        """Render token sequences as a horizontal strip of images.

        ``seq`` holds one sequence per row, laid out as
        ``frame, direction, frame, direction, ...``, where each frame is
        ``height * width`` tokens. Frames are rendered with ``frame2img``.
        Each direction token becomes a ``scale``-pixel-wide panel, framed by
        one-pixel black separators, that shows:
        - a ">" for ``token_forward``;
        - a "<" for ``token_backward``;
        - an X for any other token.

        Returns:
            An integer RGB tensor with ``height * scale - 1`` rows, all
            pieces concatenated along the width.
        """
        frame_size = self.height * self.width
        panel_height = self.height * scale - 1
        nb = seq.size(0)
        # Top row of the glyph drawn in each direction panel.
        top = (self.height * scale) // 2 - scale // 2

        parts = [
            self.frame2img(
                seq[:, :frame_size].reshape(-1, self.height, self.width),
                scale,
            )
        ]

        separator = torch.full((nb, 3, panel_height, 1), 0)

        t = frame_size

        while t < seq.size(1):
            direction_tokens = seq[:, t]
            t += 1

            # Blank panel in the empty-cell colour; glyphs are drawn in black.
            direction_images = self.colors[
                torch.full((nb, panel_height, scale), 0)
            ].permute(0, 3, 1, 2)

            for n in range(nb):
                token = direction_tokens[n]
                if token == self.token_forward:
                    for k in range(scale):
                        for l in (0, 1):
                            direction_images[
                                n, :, top + k - l, 3 + scale // 2 - abs(k - scale // 2)
                            ] = 0
                elif token == self.token_backward:
                    for k in range(scale):
                        for l in (0, 1):
                            direction_images[
                                n, :, top + k - l, 3 + abs(k - scale // 2)
                            ] = 0
                else:
                    for k in range(2, scale - 2):
                        for l in (0, 1):
                            direction_images[n, :, top + k - l, k] = 0
                            direction_images[n, :, top + k - l, scale - 1 - k] = 0

            parts += [
                separator,
                direction_images,
                separator,
                self.frame2img(
                    seq[:, t : t + frame_size].reshape(-1, self.height, self.width),
                    scale,
                ),
            ]

            t += frame_size

        return torch.cat(parts, dim=3)

    def seq2str(self, seq):
        """Convert each token sequence in ``seq`` to a string via ``token2char``."""
        result = []
        for s in seq:
            result.append("".join([self.token2char[v] for v in s]))
        return result

    def save_image(self, input, result_dir, filename):
        """Render ``input`` with ``seq2img`` and save it as a grid, 6 per row.

        The image is written to ``os.path.join(result_dir, filename)``.
        """
        img = self.seq2img(input.to("cpu"))
        image_name = os.path.join(result_dir, filename)
        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)

    def save_quizzes(self, input, result_dir, filename_prefix):
        """Save ``input`` as ``<filename_prefix>.png`` in ``result_dir``."""
        self.save_image(input, result_dir, filename_prefix + ".png")
256         self.save_image(input, result_dir, filename_prefix + ".png")
257
258
259 ######################################################################
260
if __name__ == "__main__":
    import time

    wireworld = Wireworld(height=10, width=15, nb_iterations=4)

    start_time = time.perf_counter()
    frame_sequences = wireworld.generate_frame_sequences(nb=96)
    delay = time.perf_counter() - start_time
    # Generation throughput, with two decimals.
    print(f"{frame_sequences.size(0)/delay:.2f} seq/s")

    # One image per time step, tiling that step of every generated sequence.
    # frame_sequences[:, t] holds plain (N, height, width) frames, so render
    # them with frame2img rather than the token-sequence renderer.
    for t in range(frame_sequences.size(1)):
        img = wireworld.frame2img(frame_sequences[:, t])
        torchvision.utils.save_image(
            img.float() / 255.0,
            f"/tmp/frame_{t:03d}.png",
            nrow=8,
            padding=6,
            pad_value=0,
        )
289     # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=10, pad_value=0.1
290     # )