From 2cbce23079a358ae46d3f196d7fab3de42eb28c6 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 29 Jun 2024 16:18:30 +0300 Subject: [PATCH] Update. --- wireworld.py | 293 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 293 insertions(+) create mode 100755 wireworld.py diff --git a/wireworld.py b/wireworld.py new file mode 100755 index 0000000..98e2334 --- /dev/null +++ b/wireworld.py @@ -0,0 +1,293 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import math, sys, tqdm, os + +import torch, torchvision + +from torch import nn +from torch.nn import functional as F + +###################################################################### + +import problem + + +class Physics(problem.Problem): + colors = torch.tensor( + [ + [128, 128, 128], + [128, 128, 255], + [255, 0, 0], + [255, 255, 0], + ] + ) + + token_empty = 0 + token_head = 1 + token_tail = 2 + token_conductor = 3 + token_forward = 4 + token_backward = 5 + + token2char = ( + "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><" + ) + + def __init__( + self, height=6, width=8, nb_objects=2, nb_walls=2, speed=1, nb_iterations=4 + ): + self.height = height + self.width = width + self.nb_objects = nb_objects + self.nb_walls = nb_walls + self.speed = speed + self.nb_iterations = nb_iterations + + def direction_tokens(self): + return self.token_forward, self.token_backward + + def generate_frame_sequences(self, nb): + frame_sequences = [] + + result = torch.full( + (nb * 100, self.nb_iterations, self.height, self.width), self.token_empty + ) + + for n in range(result.size(0)): + while True: + i = torch.randint(self.height, (1,)) + j = torch.randint(self.width, (1,)) + v = torch.randint(2, (2,)) + vi = v[0] * (v[1] * 2 - 1) + vj = (1 - v[0]) * (v[1] * 2 - 1) + while True: + if i < 0 or i >= self.height or j < 0 or j >= self.width: + break + result[n, 0, i, j] = self.token_conductor + i += vi + j += vj + if torch.rand(1) < 0.5: + break + + weight = torch.full((1, 1, 3, 3), 1.0) + + mask = (torch.rand(result[:, 0].size()) < 0.01).long() + rand = torch.randint(4, mask.size()) + result[:, 0] = mask * rand + (1 - mask) * result[:, 0] + + # empty->empty + # head->tail + # tail->conductor + # conductor->head if 1 or 2 head in the neighborhood, or remains conductor + + for l in range(self.nb_iterations - 1): + nb_head_neighbors = ( + F.conv2d( + input=(result[:, l] == self.token_head).float()[:, None, :, :], + weight=weight, + padding=1, + ) + .long() + .squeeze(1) + ) + mask_1_or_2_heads = (nb_head_neighbors == 1).long() + ( + nb_head_neighbors == 2 + ).long() + result[:, l + 1] = ( + (result[:, l] == self.token_empty).long() * self.token_empty + + (result[:, l] == self.token_head).long() * self.token_tail + + (result[:, l] == self.token_tail).long() * self.token_conductor + + (result[:, l] == self.token_conductor).long() + * ( + mask_1_or_2_heads * self.token_head + + (1 - mask_1_or_2_heads) * self.token_conductor + ) + ) + + i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0 + + result = result[i] + + if result.size(0) < nb: + print(result.size(0)) + result = torch.cat( + [result, self.generate_frame_sequences(nb - result.size(0))], dim=0 + ) + + return result + + def generate_token_sequences(self, nb): + frame_sequences = self.generate_frame_sequences(nb) + + result = [] + + for frame_sequence in frame_sequences: + a = [] + if torch.rand(1) < 0.5: + for frame in frame_sequence: + if len(a) > 0: + a.append(torch.tensor([self.token_forward])) + a.append(frame.flatten()) + else: + for frame in reversed(frame_sequence): + if len(a) > 0: + a.append(torch.tensor([self.token_backward])) + a.append(frame.flatten()) + + result.append(torch.cat(a, dim=0)[None, :]) + + return torch.cat(result, dim=0) + + ###################################################################### + + def frame2img(self, x, scale=15): + x = x.reshape(-1, self.height, self.width) + m = torch.logical_and(x >= 0, x < 4).long() + + x = self.colors[x * m].permute(0, 3, 1, 2) + s = x.shape + x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale) + x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale) + + x[:, :, :, torch.arange(0, x.size(3), scale)] = 0 + x[:, :, torch.arange(0, x.size(2), scale), :] = 0 + x = x[:, :, 1:, 1:] + + for n in range(m.size(0)): + for i in range(m.size(1)): + for j in range(m.size(2)): + if m[n, i, j] == 0: + for k in range(2, scale - 2): + for l in [0, 1]: + x[n, :, i * scale + k, j * scale + k - l] = 0 + x[ + n, :, i * scale + scale - 1 - k, j * scale + k - l + ] = 0 + + return x + + def seq2img(self, seq, scale=15): + all = [ + self.frame2img( + seq[:, : self.height * self.width].reshape(-1, self.height, self.width), + scale, + ) + ] + + separator = torch.full((seq.size(0), 3, self.height * scale - 1, 1), 0) + + t = self.height * self.width + + while t < seq.size(1): + direction_tokens = seq[:, t] + t += 1 + + direction_images = self.colors[ + torch.full( + (direction_tokens.size(0), self.height * scale - 1, scale), 0 + ) + ].permute(0, 3, 1, 2) + + for n in range(direction_tokens.size(0)): + if direction_tokens[n] == self.token_forward: + for k in range(scale): + for l in [0, 1]: + direction_images[ + n, + :, + (self.height * scale) // 2 - scale // 2 + k - l, + 3 + scale // 2 - abs(k - scale // 2), + ] = 0 + elif direction_tokens[n] == self.token_backward: + for k in range(scale): + for l in [0, 1]: + direction_images[ + n, + :, + (self.height * scale) // 2 - scale // 2 + k - l, + 3 + abs(k - scale // 2), + ] = 0 + else: + for k in range(2, scale - 2): + for l in [0, 1]: + direction_images[ + n, + :, + (self.height * scale) // 2 - scale // 2 + k - l, + k, + ] = 0 + direction_images[ + n, + :, + (self.height * scale) // 2 - scale // 2 + k - l, + scale - 1 - k, + ] = 0 + + all += [ + separator, + direction_images, + separator, + self.frame2img( + seq[:, t : t + self.height * self.width].reshape( + -1, self.height, self.width + ), + scale, + ), + ] + + t += self.height * self.width + + return torch.cat(all, dim=3) + + def seq2str(self, seq): + result = [] + for s in seq: + result.append("".join([self.token2char[v] for v in s])) + return result + + def save_image(self, input, result_dir, filename): + img = self.seq2img(input.to("cpu")) + image_name = os.path.join(result_dir, filename) + torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4) + + def save_quizzes(self, input, result_dir, filename_prefix): + self.save_image(input, result_dir, filename_prefix + ".png") + + +###################################################################### + +if __name__ == "__main__": + import time + + sky = Physics(height=10, width=15, speed=1, nb_iterations=100) + + start_time = time.perf_counter() + frame_sequences = sky.generate_frame_sequences(nb=96) + delay = time.perf_counter() - start_time + print(f"{frame_sequences.size(0)/delay:02f} seq/s") + + # print(sky.seq2str(seq[:4])) + + for t in range(frame_sequences.size(1)): + img = sky.seq2img(frame_sequences[:, t]) + torchvision.utils.save_image( + img.float() / 255.0, + f"/tmp/frame_{t:03d}.png", + nrow=8, + padding=6, + pad_value=0, + ) + + # m = (torch.rand(seq.size()) < 0.05).long() + # seq = (1 - m) * seq + m * 23 + + # img = sky.seq2img(frame_sequences[:60]) + + # torchvision.utils.save_image( + # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=10, pad_value=0.1 + # ) -- 2.39.5