Update.
[culture.git] / wireworld.py
index 98e2334..65b12ad 100755 (executable)
@@ -17,7 +17,7 @@ from torch.nn import functional as F
 import problem
 
 
-class Physics(problem.Problem):
+class Wireworld(problem.Problem):
     colors = torch.tensor(
         [
             [128, 128, 128],
@@ -52,10 +52,20 @@ class Physics(problem.Problem):
         return self.token_forward, self.token_backward
 
     def generate_frame_sequences(self, nb):
+        result = []
+        N = 100
+        for _ in tqdm.tqdm(
+            range(0, nb + N, N), dynamic_ncols=True, desc="world generation"
+        ):
+            result.append(self.generate_frame_sequences_hard(100))
+        return torch.cat(result, dim=0)[:nb]
+
+    def generate_frame_sequences_hard(self, nb):
         frame_sequences = []
 
         result = torch.full(
-            (nb * 100, self.nb_iterations, self.height, self.width), self.token_empty
+            (nb * 4, self.nb_iterations * self.speed, self.height, self.width),
+            self.token_empty,
         )
 
         for n in range(result.size(0)):
@@ -68,10 +78,45 @@ class Physics(problem.Problem):
                 while True:
                     if i < 0 or i >= self.height or j < 0 or j >= self.width:
                         break
+                    o = 0
+                    if i > 0:
+                        o += (result[n, 0, i - 1, j] == self.token_conductor).long()
+                    if i < self.height - 1:
+                        o += (result[n, 0, i + 1, j] == self.token_conductor).long()
+                    if j > 0:
+                        o += (result[n, 0, i, j - 1] == self.token_conductor).long()
+                    if j < self.width - 1:
+                        o += (result[n, 0, i, j + 1] == self.token_conductor).long()
+                    if o > 1:
+                        break
                     result[n, 0, i, j] = self.token_conductor
                     i += vi
                     j += vj
-                if torch.rand(1) < 0.5:
+                if (
+                    result[n, 0] == self.token_conductor
+                ).long().sum() > self.width and torch.rand(1) < 0.5:
+                    break
+
+            while True:
+                for _ in range(self.height * self.width):
+                    i = torch.randint(self.height, (1,))
+                    j = torch.randint(self.width, (1,))
+                    v = torch.randint(2, (2,))
+                    vi = v[0] * (v[1] * 2 - 1)
+                    vj = (1 - v[0]) * (v[1] * 2 - 1)
+                    if (
+                        i + vi >= 0
+                        and i + vi < self.height
+                        and j + vj >= 0
+                        and j + vj < self.width
+                        and result[n, 0, i, j] == self.token_conductor
+                        and result[n, 0, i + vi, j + vj] == self.token_conductor
+                    ):
+                        result[n, 0, i, j] = self.token_head
+                        result[n, 0, i + vi, j + vj] = self.token_tail
+                        break
+
+                if torch.rand(1) < 0.75:
                     break
 
         weight = torch.full((1, 1, 3, 3), 1.0)
@@ -85,7 +130,7 @@ class Physics(problem.Problem):
         # tail->conductor
         # conductor->head if 1 or 2 head in the neighborhood, or remains conductor
 
-        for l in range(self.nb_iterations - 1):
+        for l in range(self.nb_iterations * self.speed - 1):
             nb_head_neighbors = (
                 F.conv2d(
                     input=(result[:, l] == self.token_head).float()[:, None, :, :],
@@ -109,17 +154,22 @@ class Physics(problem.Problem):
                 )
             )
 
-        i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0
+        result = result[
+            :, torch.arange(self.nb_iterations, device=result.device) * self.speed
+        ]
 
+        i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0
         result = result[i]
 
+        # print(f"{result.size(0)=} {nb=}")
+
         if result.size(0) < nb:
-            print(result.size(0))
+            print(result.size(0))
             result = torch.cat(
                 [result, self.generate_frame_sequences(nb - result.size(0))], dim=0
             )
 
-        return result
+        return result[:nb]
 
     def generate_token_sequences(self, nb):
         frame_sequences = self.generate_frame_sequences(nb)
@@ -264,17 +314,17 @@ class Physics(problem.Problem):
 if __name__ == "__main__":
     import time
 
-    sky = Physics(height=10, width=15, speed=1, nb_iterations=100)
+    wireworld = Wireworld(height=8, width=10, nb_iterations=5, speed=1)
 
     start_time = time.perf_counter()
-    frame_sequences = sky.generate_frame_sequences(nb=96)
+    frame_sequences = wireworld.generate_frame_sequences(nb=96)
     delay = time.perf_counter() - start_time
     print(f"{frame_sequences.size(0)/delay:02f} seq/s")
 
-    # print(sky.seq2str(seq[:4]))
+    # print(wireworld.seq2str(seq[:4]))
 
     for t in range(frame_sequences.size(1)):
-        img = sky.seq2img(frame_sequences[:, t])
+        img = wireworld.seq2img(frame_sequences[:, t])
         torchvision.utils.save_image(
             img.float() / 255.0,
             f"/tmp/frame_{t:03d}.png",
@@ -286,7 +336,10 @@ if __name__ == "__main__":
     # m = (torch.rand(seq.size()) < 0.05).long()
     # seq = (1 - m) * seq + m * 23
 
-    # img = sky.seq2img(frame_sequences[:60])
+    wireworld = Wireworld(height=8, width=10, nb_iterations=2, speed=5)
+    token_sequences = wireworld.generate_token_sequences(32)
+    wireworld.save_quizzes(token_sequences, "/tmp", "seq")
+    # img = wireworld.seq2img(frame_sequences[:60])
 
     # torchvision.utils.save_image(
     # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=10, pad_value=0.1