X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=wireworld.py;h=8257cadfc6599acb2bc0f4cb72b6245b762897c2;hb=5a77666812f943678094edea26bc17dff8304073;hp=76c00e517b2977bdc4b829b9390220ff07d72114;hpb=aae01e186a959131b446d0365c6b951bacfd71d9;p=culture.git diff --git a/wireworld.py b/wireworld.py index 76c00e5..8257cad 100755 --- a/wireworld.py +++ b/wireworld.py @@ -52,10 +52,20 @@ class Wireworld(problem.Problem): return self.token_forward, self.token_backward def generate_frame_sequences(self, nb): + result = [] + N = 100 + for _ in tqdm.tqdm( + range(0, nb + N, N), dynamic_ncols=True, desc="world generation" + ): + result.append(self.generate_frame_sequences_hard(100)) + return torch.cat(result, dim=0)[:nb] + + def generate_frame_sequences_hard(self, nb): frame_sequences = [] + nb_frames = (self.nb_iterations - 1) * self.speed + 1 result = torch.full( - (nb * 4, self.nb_iterations * self.speed, self.height, self.width), + (nb * 4, nb_frames, self.height, self.width), self.token_empty, ) @@ -107,21 +117,24 @@ class Wireworld(problem.Problem): result[n, 0, i + vi, j + vj] = self.token_tail break - if torch.rand(1) < 0.75: - break + # if torch.rand(1) < 0.75: + break weight = torch.full((1, 1, 3, 3), 1.0) - # mask = (torch.rand(result[:, 0].size()) < 0.01).long() - # rand = torch.randint(4, mask.size()) - # result[:, 0] = mask * rand + (1 - mask) * result[:, 0] + mask = (torch.rand(result[:, 0].size()) < 0.01).long() + rand = torch.randint(4, mask.size()) + result[:, 0] = mask * rand + (1 - mask) * result[:, 0] # empty->empty # head->tail # tail->conductor # conductor->head if 1 or 2 head in the neighborhood, or remains conductor - for l in range(self.nb_iterations * self.speed - 1): + nb_heads = (result[:, 0] == self.token_head).flatten(1).long().sum(dim=1) + valid = nb_heads > 0 + + for l in range(nb_frames - 1): nb_head_neighbors = ( F.conv2d( input=(result[:, l] == self.token_head).float()[:, None, :, :], @@ -144,6 +157,13 @@ class Wireworld(problem.Problem): + (1 - mask_1_or_2_heads) * self.token_conductor ) ) + pred_nb_heads = nb_heads + nb_heads = ( + (result[:, l + 1] == self.token_head).flatten(1).long().sum(dim=1) + ) + valid = torch.logical_and(valid, (nb_heads >= pred_nb_heads)) + + result = result[valid] result = result[ :, torch.arange(self.nb_iterations, device=result.device) * self.speed @@ -152,7 +172,7 @@ class Wireworld(problem.Problem): i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0 result = result[i] - print(f"{result.size(0)=} {nb=}") + # print(f"{result.size(0)=} {nb=}") if result.size(0) < nb: # print(result.size(0)) @@ -305,7 +325,7 @@ class Wireworld(problem.Problem): if __name__ == "__main__": import time - wireworld = Wireworld(height=10, width=15, nb_iterations=2, speed=5) + wireworld = Wireworld(height=8, width=10, nb_iterations=5, speed=1) start_time = time.perf_counter() frame_sequences = wireworld.generate_frame_sequences(nb=96) @@ -314,19 +334,20 @@ if __name__ == "__main__": # print(wireworld.seq2str(seq[:4])) - # for t in range(frame_sequences.size(1)): - # img = wireworld.seq2img(frame_sequences[:, t]) - # torchvision.utils.save_image( - # img.float() / 255.0, - # f"/tmp/frame_{t:03d}.png", - # nrow=8, - # padding=6, - # pad_value=0, - # ) + for t in range(frame_sequences.size(1)): + img = wireworld.seq2img(frame_sequences[:, t]) + torchvision.utils.save_image( + img.float() / 255.0, + f"/tmp/frame_{t:03d}.png", + nrow=8, + padding=6, + pad_value=0, + ) # m = (torch.rand(seq.size()) < 0.05).long() # seq = (1 - m) * seq + m * 23 + wireworld = Wireworld(height=8, width=10, nb_iterations=2, speed=5) token_sequences = wireworld.generate_token_sequences(32) wireworld.save_quizzes(token_sequences, "/tmp", "seq") # img = wireworld.seq2img(frame_sequences[:60])