Merge branch 'dev'
[culture.git] / wireworld.py
index 76c00e5..8257cad 100755 (executable)
@@ -52,10 +52,20 @@ class Wireworld(problem.Problem):
         return self.token_forward, self.token_backward
 
     def generate_frame_sequences(self, nb):
         return self.token_forward, self.token_backward
 
     def generate_frame_sequences(self, nb):
+        result = []
+        N = 100
+        for _ in tqdm.tqdm(
+            range(0, nb + N, N), dynamic_ncols=True, desc="world generation"
+        ):
+            result.append(self.generate_frame_sequences_hard(100))
+        return torch.cat(result, dim=0)[:nb]
+
+    def generate_frame_sequences_hard(self, nb):
         frame_sequences = []
         frame_sequences = []
+        nb_frames = (self.nb_iterations - 1) * self.speed + 1
 
         result = torch.full(
 
         result = torch.full(
-            (nb * 4, self.nb_iterations * self.speed, self.height, self.width),
+            (nb * 4, nb_frames, self.height, self.width),
             self.token_empty,
         )
 
             self.token_empty,
         )
 
@@ -107,21 +117,24 @@ class Wireworld(problem.Problem):
                         result[n, 0, i + vi, j + vj] = self.token_tail
                         break
 
                         result[n, 0, i + vi, j + vj] = self.token_tail
                         break
 
-                if torch.rand(1) < 0.75:
-                    break
+                if torch.rand(1) < 0.75:
+                break
 
         weight = torch.full((1, 1, 3, 3), 1.0)
 
 
         weight = torch.full((1, 1, 3, 3), 1.0)
 
-        mask = (torch.rand(result[:, 0].size()) < 0.01).long()
-        rand = torch.randint(4, mask.size())
-        result[:, 0] = mask * rand + (1 - mask) * result[:, 0]
+        mask = (torch.rand(result[:, 0].size()) < 0.01).long()
+        rand = torch.randint(4, mask.size())
+        result[:, 0] = mask * rand + (1 - mask) * result[:, 0]
 
         # empty->empty
         # head->tail
         # tail->conductor
         # conductor->head if 1 or 2 head in the neighborhood, or remains conductor
 
 
         # empty->empty
         # head->tail
         # tail->conductor
         # conductor->head if 1 or 2 head in the neighborhood, or remains conductor
 
-        for l in range(self.nb_iterations * self.speed - 1):
+        nb_heads = (result[:, 0] == self.token_head).flatten(1).long().sum(dim=1)
+        valid = nb_heads > 0
+
+        for l in range(nb_frames - 1):
             nb_head_neighbors = (
                 F.conv2d(
                     input=(result[:, l] == self.token_head).float()[:, None, :, :],
             nb_head_neighbors = (
                 F.conv2d(
                     input=(result[:, l] == self.token_head).float()[:, None, :, :],
@@ -144,6 +157,13 @@ class Wireworld(problem.Problem):
                     + (1 - mask_1_or_2_heads) * self.token_conductor
                 )
             )
                     + (1 - mask_1_or_2_heads) * self.token_conductor
                 )
             )
+            pred_nb_heads = nb_heads
+            nb_heads = (
+                (result[:, l + 1] == self.token_head).flatten(1).long().sum(dim=1)
+            )
+            valid = torch.logical_and(valid, (nb_heads >= pred_nb_heads))
+
+        result = result[valid]
 
         result = result[
             :, torch.arange(self.nb_iterations, device=result.device) * self.speed
 
         result = result[
             :, torch.arange(self.nb_iterations, device=result.device) * self.speed
@@ -152,7 +172,7 @@ class Wireworld(problem.Problem):
         i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0
         result = result[i]
 
         i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0
         result = result[i]
 
-        print(f"{result.size(0)=} {nb=}")
+        print(f"{result.size(0)=} {nb=}")
 
         if result.size(0) < nb:
             # print(result.size(0))
 
         if result.size(0) < nb:
             # print(result.size(0))
@@ -305,7 +325,7 @@ class Wireworld(problem.Problem):
 if __name__ == "__main__":
     import time
 
 if __name__ == "__main__":
     import time
 
-    wireworld = Wireworld(height=10, width=15, nb_iterations=2, speed=5)
+    wireworld = Wireworld(height=8, width=10, nb_iterations=5, speed=1)
 
     start_time = time.perf_counter()
     frame_sequences = wireworld.generate_frame_sequences(nb=96)
 
     start_time = time.perf_counter()
     frame_sequences = wireworld.generate_frame_sequences(nb=96)
@@ -314,19 +334,20 @@ if __name__ == "__main__":
 
     # print(wireworld.seq2str(seq[:4]))
 
 
     # print(wireworld.seq2str(seq[:4]))
 
-    for t in range(frame_sequences.size(1)):
-    # img = wireworld.seq2img(frame_sequences[:, t])
-    # torchvision.utils.save_image(
-    # img.float() / 255.0,
-    # f"/tmp/frame_{t:03d}.png",
-    # nrow=8,
-    # padding=6,
-    # pad_value=0,
-    # )
+    for t in range(frame_sequences.size(1)):
+        img = wireworld.seq2img(frame_sequences[:, t])
+        torchvision.utils.save_image(
+            img.float() / 255.0,
+            f"/tmp/frame_{t:03d}.png",
+            nrow=8,
+            padding=6,
+            pad_value=0,
+        )
 
     # m = (torch.rand(seq.size()) < 0.05).long()
     # seq = (1 - m) * seq + m * 23
 
 
     # m = (torch.rand(seq.size()) < 0.05).long()
     # seq = (1 - m) * seq + m * 23
 
+    wireworld = Wireworld(height=8, width=10, nb_iterations=2, speed=5)
     token_sequences = wireworld.generate_token_sequences(32)
     wireworld.save_quizzes(token_sequences, "/tmp", "seq")
     # img = wireworld.seq2img(frame_sequences[:60])
     token_sequences = wireworld.generate_token_sequences(32)
     wireworld.save_quizzes(token_sequences, "/tmp", "seq")
     # img = wireworld.seq2img(frame_sequences[:60])