X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=wireworld.py;h=8257cadfc6599acb2bc0f4cb72b6245b762897c2;hb=ed64a064ef1d8d3e53c7961480ffafdd516ea984;hp=98e2334e5f218bb21add43f982ee0783658c6e68;hpb=2cbce23079a358ae46d3f196d7fab3de42eb28c6;p=culture.git diff --git a/wireworld.py b/wireworld.py index 98e2334..8257cad 100755 --- a/wireworld.py +++ b/wireworld.py @@ -17,7 +17,7 @@ from torch.nn import functional as F import problem -class Physics(problem.Problem): +class Wireworld(problem.Problem): colors = torch.tensor( [ [128, 128, 128], @@ -52,10 +52,21 @@ class Physics(problem.Problem): return self.token_forward, self.token_backward def generate_frame_sequences(self, nb): + result = [] + N = 100 + for _ in tqdm.tqdm( + range(0, nb + N, N), dynamic_ncols=True, desc="world generation" + ): + result.append(self.generate_frame_sequences_hard(100)) + return torch.cat(result, dim=0)[:nb] + + def generate_frame_sequences_hard(self, nb): frame_sequences = [] + nb_frames = (self.nb_iterations - 1) * self.speed + 1 result = torch.full( - (nb * 100, self.nb_iterations, self.height, self.width), self.token_empty + (nb * 4, nb_frames, self.height, self.width), + self.token_empty, ) for n in range(result.size(0)): @@ -68,12 +79,47 @@ class Physics(problem.Problem): while True: if i < 0 or i >= self.height or j < 0 or j >= self.width: break + o = 0 + if i > 0: + o += (result[n, 0, i - 1, j] == self.token_conductor).long() + if i < self.height - 1: + o += (result[n, 0, i + 1, j] == self.token_conductor).long() + if j > 0: + o += (result[n, 0, i, j - 1] == self.token_conductor).long() + if j < self.width - 1: + o += (result[n, 0, i, j + 1] == self.token_conductor).long() + if o > 1: + break result[n, 0, i, j] = self.token_conductor i += vi j += vj - if torch.rand(1) < 0.5: + if ( + result[n, 0] == self.token_conductor + ).long().sum() > self.width and torch.rand(1) < 0.5: break + while True: + for _ in range(self.height * self.width): + i = torch.randint(self.height, (1,)) + j = torch.randint(self.width, (1,)) + v = torch.randint(2, (2,)) + vi = v[0] * (v[1] * 2 - 1) + vj = (1 - v[0]) * (v[1] * 2 - 1) + if ( + i + vi >= 0 + and i + vi < self.height + and j + vj >= 0 + and j + vj < self.width + and result[n, 0, i, j] == self.token_conductor + and result[n, 0, i + vi, j + vj] == self.token_conductor + ): + result[n, 0, i, j] = self.token_head + result[n, 0, i + vi, j + vj] = self.token_tail + break + + # if torch.rand(1) < 0.75: + break + weight = torch.full((1, 1, 3, 3), 1.0) mask = (torch.rand(result[:, 0].size()) < 0.01).long() @@ -85,7 +131,10 @@ class Physics(problem.Problem): # tail->conductor # conductor->head if 1 or 2 head in the neighborhood, or remains conductor - for l in range(self.nb_iterations - 1): + nb_heads = (result[:, 0] == self.token_head).flatten(1).long().sum(dim=1) + valid = nb_heads > 0 + + for l in range(nb_frames - 1): nb_head_neighbors = ( F.conv2d( input=(result[:, l] == self.token_head).float()[:, None, :, :], @@ -108,18 +157,30 @@ class Physics(problem.Problem): + (1 - mask_1_or_2_heads) * self.token_conductor ) ) + pred_nb_heads = nb_heads + nb_heads = ( + (result[:, l + 1] == self.token_head).flatten(1).long().sum(dim=1) + ) + valid = torch.logical_and(valid, (nb_heads >= pred_nb_heads)) - i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0 + result = result[valid] + + result = result[ + :, torch.arange(self.nb_iterations, device=result.device) * self.speed + ] + i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0 result = result[i] + # print(f"{result.size(0)=} {nb=}") + if result.size(0) < nb: - print(result.size(0)) + # print(result.size(0)) result = torch.cat( [result, self.generate_frame_sequences(nb - result.size(0))], dim=0 ) - return result + return result[:nb] def generate_token_sequences(self, nb): frame_sequences = self.generate_frame_sequences(nb) @@ -264,17 +325,17 @@ class Physics(problem.Problem): if __name__ == "__main__": import time - sky = Physics(height=10, width=15, speed=1, nb_iterations=100) + wireworld = Wireworld(height=8, width=10, nb_iterations=5, speed=1) start_time = time.perf_counter() - frame_sequences = sky.generate_frame_sequences(nb=96) + frame_sequences = wireworld.generate_frame_sequences(nb=96) delay = time.perf_counter() - start_time print(f"{frame_sequences.size(0)/delay:02f} seq/s") - # print(sky.seq2str(seq[:4])) + # print(wireworld.seq2str(seq[:4])) for t in range(frame_sequences.size(1)): - img = sky.seq2img(frame_sequences[:, t]) + img = wireworld.seq2img(frame_sequences[:, t]) torchvision.utils.save_image( img.float() / 255.0, f"/tmp/frame_{t:03d}.png", @@ -286,7 +347,10 @@ if __name__ == "__main__": # m = (torch.rand(seq.size()) < 0.05).long() # seq = (1 - m) * seq + m * 23 - # img = sky.seq2img(frame_sequences[:60]) + wireworld = Wireworld(height=8, width=10, nb_iterations=2, speed=5) + token_sequences = wireworld.generate_token_sequences(32) + wireworld.save_quizzes(token_sequences, "/tmp", "seq") + # img = wireworld.seq2img(frame_sequences[:60]) # torchvision.utils.save_image( # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=10, pad_value=0.1