3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, tqdm, os
10 import torch, torchvision
13 from torch.nn import functional as F
15 ######################################################################
# NOTE(review): this chunk is a partially-extracted listing -- the original
# file's line numbers are fused into each line and many intermediate lines
# are missing.  Code is kept verbatim below and only annotated.
# Wireworld cellular-automaton quiz generator (subclass of problem.Problem).
20 class Wireworld(problem.Problem):
# Per-token RGB colors; the length of `colors` fixes the token alphabet size
# (the literal rows are among the missing lines).
21     colors = torch.tensor(
# One printable character per token: '_' for the first color, then 'A', 'B',
# ..., plus '>' and '<' -- presumably the forward/backward direction tokens
# used by seq2str; confirm against the missing token definitions.
38         "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
# Constructor parameters (the `def __init__(` line itself is missing): grid
# geometry, object/wall counts, `speed` (frames between recorded iterations,
# per its use in generate_frame_sequences_hard) and `nb_iterations`.
42         self, height=6, width=8, nb_objects=2, nb_walls=2, speed=1, nb_iterations=4
46         self.nb_objects = nb_objects
47         self.nb_walls = nb_walls
49         self.nb_iterations = nb_iterations
def direction_tokens(self):
    """Return the pair of direction-marker tokens: (forward, backward)."""
    return (self.token_forward, self.token_backward)
# NOTE(review): fragmentary extracted listing -- original line numbers are
# fused into each line and lines 55-57/59 are missing; annotated verbatim.
# Generate `nb` frame sequences by calling the "hard" generator in batches
# (of 100, per the visible call) and concatenating, then trimming to `nb`.
54 def generate_frame_sequences(self, nb):
# tqdm progress bar over batches; the missing lines presumably initialise
# `result = []` and the batch size `N` -- TODO confirm.
58 range(0, nb + N, N), dynamic_ncols=True, desc="world generation"
60 result.append(self.generate_frame_sequences_hard(100))
# Concatenate all batches along dim 0 and keep exactly `nb` sequences.
61 return torch.cat(result, dim=0)[:nb]
# NOTE(review): fragmentary extracted listing -- original line numbers are
# fused into each line and many lines are missing; annotated verbatim only.
# Build random Wireworld initial states, simulate them, and keep only the
# sequences that pass the validity filters near the end.
63 def generate_frame_sequences_hard(self, nb):
# Total simulated frames so that `nb_iterations` frames can later be sampled
# every `self.speed` steps (see the arange(...) * self.speed line below).
65 nb_frames = (self.nb_iterations - 1) * self.speed + 1
# Over-generate 4x the requested count; invalid sequences are filtered later.
68 (nb * 4, nb_frames, self.height, self.width),
# --- initial-state construction: lay down conductor paths per sequence ---
72 for n in range(result.size(0)):
74 i = torch.randint(self.height, (1,))
75 j = torch.randint(self.width, (1,))
# (vi, vj) is a random axis-aligned unit step: one of (+-1, 0) or (0, +-1).
76 v = torch.randint(2, (2,))
77 vi = v[0] * (v[1] * 2 - 1)
78 vj = (1 - v[0]) * (v[1] * 2 - 1)
# Stop the walk when it leaves the grid.
80 if i < 0 or i >= self.height or j < 0 or j >= self.width:
# Count 4-neighbours already holding a conductor (boundary-guarded);
# presumably used to avoid thickening paths -- the test line is missing.
84 o += (result[n, 0, i - 1, j] == self.token_conductor).long()
85 if i < self.height - 1:
86 o += (result[n, 0, i + 1, j] == self.token_conductor).long()
88 o += (result[n, 0, i, j - 1] == self.token_conductor).long()
89 if j < self.width - 1:
90 o += (result[n, 0, i, j + 1] == self.token_conductor).long()
93 result[n, 0, i, j] = self.token_conductor
# Once enough conductor cells exist (more than one row's worth), randomly
# stop adding more with probability 0.5.
97 result[n, 0] == self.token_conductor
98 ).long().sum() > self.width and torch.rand(1) < 0.5:
# --- seed electrons: turn adjacent conductor pairs into head/tail ---
102 for _ in range(self.height * self.width):
103 i = torch.randint(self.height, (1,))
104 j = torch.randint(self.width, (1,))
105 v = torch.randint(2, (2,))
106 vi = v[0] * (v[1] * 2 - 1)
107 vj = (1 - v[0]) * (v[1] * 2 - 1)
# Require both cells in-bounds and both currently conductor.
110 and i + vi < self.height
112 and j + vj < self.width
113 and result[n, 0, i, j] == self.token_conductor
114 and result[n, 0, i + vi, j + vj] == self.token_conductor
116 result[n, 0, i, j] = self.token_head
117 result[n, 0, i + vi, j + vj] = self.token_tail
120 # if torch.rand(1) < 0.75:
# 3x3 all-ones kernel used below to count head neighbours via conv2d.
123 weight = torch.full((1, 1, 3, 3), 1.0)
# Corrupt ~1% of the cells of frame 0 with random tokens in [0, 4); this
# presumably makes the generation "hard" (some corrupted states die out).
125 mask = (torch.rand(result[:, 0].size()) < 0.01).long()
126 rand = torch.randint(4, mask.size())
127 result[:, 0] = mask * rand + (1 - mask) * result[:, 0]
132 # conductor->head if 1 or 2 head in the neighborhood, or remains conductor
# Head count of the initial frame, used by the monotonicity filter below.
134 nb_heads = (result[:, 0] == self.token_head).flatten(1).long().sum(dim=1)
# --- simulate nb_frames - 1 Wireworld steps ---
137 for l in range(nb_frames - 1):
138 nb_head_neighbors = (
140 input=(result[:, l] == self.token_head).float()[:, None, :, :],
# Wireworld rule: a conductor becomes a head iff it has 1 or 2 head
# neighbours (mask below), otherwise it stays a conductor.
147 mask_1_or_2_heads = (nb_head_neighbors == 1).long() + (
148 nb_head_neighbors == 2
# Per-cell transition: empty stays empty, head -> tail, tail -> conductor,
# conductor -> head or conductor depending on the neighbour mask.
151 (result[:, l] == self.token_empty).long() * self.token_empty
152 + (result[:, l] == self.token_head).long() * self.token_tail
153 + (result[:, l] == self.token_tail).long() * self.token_conductor
154 + (result[:, l] == self.token_conductor).long()
156 mask_1_or_2_heads * self.token_head
157 + (1 - mask_1_or_2_heads) * self.token_conductor
# Keep only sequences whose head count never decreases between frames.
160 pred_nb_heads = nb_heads
162 (result[:, l + 1] == self.token_head).flatten(1).long().sum(dim=1)
164 valid = torch.logical_and(valid, (nb_heads >= pred_nb_heads))
166 result = result[valid]
# Subsample one frame every `self.speed` steps -> nb_iterations frames.
169 :, torch.arange(self.nb_iterations, device=result.device) * self.speed
# Require at least one head in the last kept frame.
172 i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0
175 # print(f"{result.size(0)=} {nb=}")
# If filtering left fewer than `nb`, top up recursively via the batched
# wrapper generate_frame_sequences.
177 if result.size(0) < nb:
178 # print(result.size(0))
180 [result, self.generate_frame_sequences(nb - result.size(0))], dim=0
# NOTE(review): fragmentary extracted listing -- original line numbers are
# fused into each line and several lines are missing; annotated verbatim.
# Flatten each frame sequence into one token sequence, in temporal order or
# reversed (50/50), inserting the matching direction token between frames.
185 def generate_token_sequences(self, nb):
186 frame_sequences = self.generate_frame_sequences(nb)
190 for frame_sequence in frame_sequences:
# Coin flip: forward ordering with token_forward separators...
192 if torch.rand(1) < 0.5:
193 for frame in frame_sequence:
195 a.append(torch.tensor([self.token_forward]))
196 a.append(frame.flatten())
# ...or reversed ordering with token_backward separators (the `else:` line
# is among the missing ones -- TODO confirm).
198 for frame in reversed(frame_sequence):
200 a.append(torch.tensor([self.token_backward]))
201 a.append(frame.flatten())
# Stack each flattened sequence as one row.
203 result.append(torch.cat(a, dim=0)[None, :])
205 return torch.cat(result, dim=0)
207 ######################################################################
# NOTE(review): fragmentary extracted listing -- original line numbers are
# fused into each line and several lines are missing; annotated verbatim.
# Render grids of tokens as RGB images, `scale` pixels per cell, with black
# grid lines and a cross drawn over cells holding out-of-range tokens.
209 def frame2img(self, x, scale=15):
210 x = x.reshape(-1, self.height, self.width)
# m == 1 for tokens in the displayable range [0, 4); others are drawn as
# color 0 (via x * m) and later crossed out.
211 m = torch.logical_and(x >= 0, x < 4).long()
# Index the per-token color table, then move channels first (N, 3, H, W).
213 x = self.colors[x * m].permute(0, 3, 1, 2)
# Upscale by `scale` in both spatial dimensions via expand + reshape.
215 x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
216 x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
# Black 1-pixel grid lines every `scale` pixels.
218 x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
219 x[:, :, torch.arange(0, x.size(2), scale), :] = 0
# Draw a diagonal cross over each invalid cell (the `if m[n, i, j] == 0`
# guard is presumably among the missing lines -- TODO confirm).
222 for n in range(m.size(0)):
223 for i in range(m.size(1)):
224 for j in range(m.size(2)):
226 for k in range(2, scale - 2):
228 x[n, :, i * scale + k, j * scale + k - l] = 0
230 n, :, i * scale + scale - 1 - k, j * scale + k - l
# NOTE(review): fragmentary extracted listing -- original line numbers are
# fused into each line and many lines are missing; annotated verbatim only.
# Render a batch of token sequences: alternate frame images and small
# direction-marker images ('>' / '<' arrows), concatenated horizontally.
235 def seq2img(self, seq, scale=15):
# First frame occupies the first height*width tokens of each sequence.
238 seq[:, : self.height * self.width].reshape(-1, self.height, self.width),
# 1-pixel-wide black vertical separator between rendered elements.
243 separator = torch.full((seq.size(0), 3, self.height * scale - 1, 1), 0)
# `t` walks the token stream one frame at a time; each frame is preceded by
# one direction token, consumed at position t.
245 t = self.height * self.width
247 while t < seq.size(1):
248 direction_tokens = seq[:, t]
# Blank tiles (color 0) on which the arrow glyphs are drawn pixel by pixel.
251 direction_images = self.colors[
253 (direction_tokens.size(0), self.height * scale - 1, scale), 0
255 ].permute(0, 3, 1, 2)
257 for n in range(direction_tokens.size(0)):
# '>' style arrow for forward sequences.
258 if direction_tokens[n] == self.token_forward:
259 for k in range(scale):
264 (self.height * scale) // 2 - scale // 2 + k - l,
265 3 + scale // 2 - abs(k - scale // 2),
# '<' style arrow for backward sequences.
267 elif direction_tokens[n] == self.token_backward:
268 for k in range(scale):
273 (self.height * scale) // 2 - scale // 2 + k - l,
274 3 + abs(k - scale // 2),
# Fallback branch (presumably an unknown direction token): draw a cross --
# the `else:` line is among the missing ones, TODO confirm.
277 for k in range(2, scale - 2):
282 (self.height * scale) // 2 - scale // 2 + k - l,
288 (self.height * scale) // 2 - scale // 2 + k - l,
# Next frame's tokens follow the direction token.
297 seq[:, t : t + self.height * self.width].reshape(
298 -1, self.height, self.width
304 t += self.height * self.width
# Concatenate all rendered pieces along the width dimension.
306 return torch.cat(all, dim=3)
# NOTE(review): fragmentary extracted listing -- original line numbers are
# fused into each line and the loop/return scaffolding lines are missing.
# Convert each token sequence to a printable string using the per-token
# characters in self.token2char.
308 def seq2str(self, seq):
311 result.append("".join([self.token2char[v] for v in s]))
def save_image(self, input, result_dir, filename):
    """Render the token sequences in `input` and save them as a PNG grid.

    The sequences are moved to CPU, turned into pixels via `seq2img`,
    rescaled from uint8 [0, 255] to float [0, 1], and written to
    `result_dir/filename` in a 6-column grid with 4-pixel padding.
    """
    target = os.path.join(result_dir, filename)
    pixels = self.seq2img(input.to("cpu")).float() / 255.0
    torchvision.utils.save_image(pixels, target, nrow=6, padding=4)
def save_quizzes(self, input, result_dir, filename_prefix):
    """Save the quizzes in `input` as `<filename_prefix>.png` in `result_dir`."""
    self.save_image(input, result_dir, filename_prefix + ".png")
323 ######################################################################
# NOTE(review): fragmentary extracted listing -- original line numbers are
# fused into each line and several lines (including, presumably, an
# `import time`) are missing; annotated verbatim only.
# Smoke test: benchmark frame generation, dump per-step frame images, then
# generate and save a batch of token-sequence quizzes.
325 if __name__ == "__main__":
328 wireworld = Wireworld(height=8, width=10, nb_iterations=5, speed=1)
# Time the generation of 96 sequences and report throughput.
330 start_time = time.perf_counter()
331 frame_sequences = wireworld.generate_frame_sequences(nb=96)
332 delay = time.perf_counter() - start_time
333 print(f"{frame_sequences.size(0)/delay:02f} seq/s")
335 # print(wireworld.seq2str(seq[:4]))
# Save one PNG per simulation step under /tmp.
337 for t in range(frame_sequences.size(1)):
338 img = wireworld.seq2img(frame_sequences[:, t])
339 torchvision.utils.save_image(
341 f"/tmp/frame_{t:03d}.png",
347 # m = (torch.rand(seq.size()) < 0.05).long()
348 # seq = (1 - m) * seq + m * 23
# Second configuration: fewer recorded iterations but higher speed.
350 wireworld = Wireworld(height=8, width=10, nb_iterations=2, speed=5)
351 token_sequences = wireworld.generate_token_sequences(32)
352 wireworld.save_quizzes(token_sequences, "/tmp", "seq")
353 # img = wireworld.seq2img(frame_sequences[:60])
355 # torchvision.utils.save_image(
356 # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=10, pad_value=0.1