Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 29 Jun 2024 20:14:38 +0000 (23:14 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 29 Jun 2024 20:14:38 +0000 (23:14 +0300)
main.py
sky.py
wireworld.py

diff --git a/main.py b/main.py
index 590bfa1..b62b4c0 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -222,7 +222,7 @@ assert args.nb_train_samples % args.batch_size == 0
 assert args.nb_test_samples % args.batch_size == 0
 
 if args.problem == "sky":
-    problem = sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2, speed=2)
+    problem = sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2, speed=3)
 elif args.problem == "wireworld":
     problem = wireworld.Wireworld(height=8, width=10, nb_iterations=4)
 else:
diff --git a/sky.py b/sky.py
index 6ba3882..1164185 100755 (executable)
--- a/sky.py
+++ b/sky.py
@@ -118,8 +118,14 @@ class Sky(problem.Problem):
                     dtype=torch.int64,
                 )
 
+                fine = torch.empty(self.nb_iterations * self.speed)
+
+                t_to_keep = (
+                    torch.arange(self.nb_iterations, device=result.device) * self.speed
+                )
+
                 for l in range(self.nb_iterations * self.speed):
-                    fine = collision_okay()
+                    fine[l] = collision_okay()
                     for n in range(self.nb_birds):
                         c = col[n]
                         result[l, i[n], j[n]] = c
@@ -139,14 +145,13 @@ class Sky(problem.Problem):
                         i[n] += vi[n]
                         j[n] += vj[n]
 
-                if fine:
+                result = result[t_to_keep]
+                fine = fine[t_to_keep]
+
+                if fine[-1]:
                     break
 
-            frame_sequences.append(
-                result[
-                    torch.arange(self.nb_iterations, device=result.device) * self.speed
-                ]
-            )
+            frame_sequences.append(result)
 
         return frame_sequences
 
@@ -296,7 +301,7 @@ class Sky(problem.Problem):
 if __name__ == "__main__":
     import time
 
-    sky = Sky(height=6, width=8, speed=2, nb_iterations=2)
+    sky = Sky(height=6, width=8, speed=4, nb_iterations=2)
 
     start_time = time.perf_counter()
     token_sequences = sky.generate_token_sequences(nb=64)
index aff236d..76c00e5 100755 (executable)
@@ -55,7 +55,8 @@ class Wireworld(problem.Problem):
         frame_sequences = []
 
         result = torch.full(
-            (nb * 4, self.nb_iterations, self.height, self.width), self.token_empty
+            (nb * 4, self.nb_iterations * self.speed, self.height, self.width),
+            self.token_empty,
         )
 
         for n in range(result.size(0)):
@@ -68,17 +69,52 @@ class Wireworld(problem.Problem):
                 while True:
                     if i < 0 or i >= self.height or j < 0 or j >= self.width:
                         break
+                    o = 0
+                    if i > 0:
+                        o += (result[n, 0, i - 1, j] == self.token_conductor).long()
+                    if i < self.height - 1:
+                        o += (result[n, 0, i + 1, j] == self.token_conductor).long()
+                    if j > 0:
+                        o += (result[n, 0, i, j - 1] == self.token_conductor).long()
+                    if j < self.width - 1:
+                        o += (result[n, 0, i, j + 1] == self.token_conductor).long()
+                    if o > 1:
+                        break
                     result[n, 0, i, j] = self.token_conductor
                     i += vi
                     j += vj
-                if torch.rand(1) < 0.5:
+                if (
+                    result[n, 0] == self.token_conductor
+                ).long().sum() > self.width and torch.rand(1) < 0.5:
+                    break
+
+            while True:
+                for _ in range(self.height * self.width):
+                    i = torch.randint(self.height, (1,))
+                    j = torch.randint(self.width, (1,))
+                    v = torch.randint(2, (2,))
+                    vi = v[0] * (v[1] * 2 - 1)
+                    vj = (1 - v[0]) * (v[1] * 2 - 1)
+                    if (
+                        i + vi >= 0
+                        and i + vi < self.height
+                        and j + vj >= 0
+                        and j + vj < self.width
+                        and result[n, 0, i, j] == self.token_conductor
+                        and result[n, 0, i + vi, j + vj] == self.token_conductor
+                    ):
+                        result[n, 0, i, j] = self.token_head
+                        result[n, 0, i + vi, j + vj] = self.token_tail
+                        break
+
+                if torch.rand(1) < 0.75:
                     break
 
         weight = torch.full((1, 1, 3, 3), 1.0)
 
-        mask = (torch.rand(result[:, 0].size()) < 0.01).long()
-        rand = torch.randint(4, mask.size())
-        result[:, 0] = mask * rand + (1 - mask) * result[:, 0]
+        mask = (torch.rand(result[:, 0].size()) < 0.01).long()
+        rand = torch.randint(4, mask.size())
+        result[:, 0] = mask * rand + (1 - mask) * result[:, 0]
 
         # empty->empty
         # head->tail
@@ -109,12 +145,15 @@ class Wireworld(problem.Problem):
                 )
             )
 
-        i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0
-
         result = result[
-            torch.arange(self.nb_iterations, device=result.device) * self.speed
+            :, torch.arange(self.nb_iterations, device=result.device) * self.speed
         ]
 
+        i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0
+        result = result[i]
+
+        print(f"{result.size(0)=} {nb=}")
+
         if result.size(0) < nb:
             # print(result.size(0))
             result = torch.cat(
@@ -266,7 +305,7 @@ class Wireworld(problem.Problem):
 if __name__ == "__main__":
     import time
 
-    wireworld = Wireworld(height=10, width=15, nb_iterations=2, speed=1)
+    wireworld = Wireworld(height=10, width=15, nb_iterations=2, speed=5)
 
     start_time = time.perf_counter()
     frame_sequences = wireworld.generate_frame_sequences(nb=96)