Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 29 Jun 2024 16:00:48 +0000 (19:00 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 29 Jun 2024 16:00:48 +0000 (19:00 +0300)
main.py
sky.py
wireworld.py

diff --git a/main.py b/main.py
index 30dcd4d..590bfa1 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -224,7 +224,7 @@ assert args.nb_test_samples % args.batch_size == 0
 if args.problem == "sky":
     problem = sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2, speed=2)
 elif args.problem == "wireworld":
-    problem = wireworld.Wireworld(height=10, width=15, nb_iterations=4)
+    problem = wireworld.Wireworld(height=8, width=10, nb_iterations=4)
 else:
     raise ValueError
 
diff --git a/sky.py b/sky.py
index abcd394..6ba3882 100755 (executable)
--- a/sky.py
+++ b/sky.py
@@ -112,10 +112,13 @@ class Sky(problem.Problem):
                         break
 
                 result = torch.zeros(
-                    self.nb_iterations, self.height, self.width, dtype=torch.int64
+                    self.nb_iterations * self.speed,
+                    self.height,
+                    self.width,
+                    dtype=torch.int64,
                 )
 
-                for l in range(self.nb_iterations):
+                for l in range(self.nb_iterations * self.speed):
                     fine = collision_okay()
                     for n in range(self.nb_birds):
                         c = col[n]
@@ -139,7 +142,11 @@ class Sky(problem.Problem):
                 if fine:
                     break
 
-            frame_sequences.append(result)
+            frame_sequences.append(
+                result[
+                    torch.arange(self.nb_iterations, device=result.device) * self.speed
+                ]
+            )
 
         return frame_sequences
 
index 219d7dd..aff236d 100755 (executable)
@@ -38,11 +38,14 @@ class Wireworld(problem.Problem):
         "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
     )
 
-    def __init__(self, height=6, width=8, nb_objects=2, nb_walls=2, nb_iterations=4):
+    def __init__(
+        self, height=6, width=8, nb_objects=2, nb_walls=2, speed=1, nb_iterations=4
+    ):
         self.height = height
         self.width = width
         self.nb_objects = nb_objects
         self.nb_walls = nb_walls
+        self.speed = speed
         self.nb_iterations = nb_iterations
 
     def direction_tokens(self):
@@ -82,7 +85,7 @@ class Wireworld(problem.Problem):
         # tail->conductor
         # conductor->head if 1 or 2 head in the neighborhood, or remains conductor
 
-        for l in range(self.nb_iterations - 1):
+        for l in range(self.nb_iterations * self.speed - 1):
             nb_head_neighbors = (
                 F.conv2d(
                     input=(result[:, l] == self.token_head).float()[:, None, :, :],
@@ -108,7 +111,9 @@ class Wireworld(problem.Problem):
 
         i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0
 
-        result = result[i]
+        result = result[
+            torch.arange(self.nb_iterations, device=result.device) * self.speed
+        ]
 
         if result.size(0) < nb:
             # print(result.size(0))
@@ -116,7 +121,7 @@ class Wireworld(problem.Problem):
                 [result, self.generate_frame_sequences(nb - result.size(0))], dim=0
             )
 
-        return result
+        return result[:nb]
 
     def generate_token_sequences(self, nb):
         frame_sequences = self.generate_frame_sequences(nb)
@@ -261,7 +266,7 @@ class Wireworld(problem.Problem):
 if __name__ == "__main__":
     import time
 
-    wireworld = Wireworld(height=10, width=15, nb_iterations=4)
+    wireworld = Wireworld(height=10, width=15, nb_iterations=2, speed=1)
 
     start_time = time.perf_counter()
     frame_sequences = wireworld.generate_frame_sequences(nb=96)
@@ -270,19 +275,21 @@ if __name__ == "__main__":
 
     # print(wireworld.seq2str(seq[:4]))
 
-    for t in range(frame_sequences.size(1)):
-        img = wireworld.seq2img(frame_sequences[:, t])
-        torchvision.utils.save_image(
-            img.float() / 255.0,
-            f"/tmp/frame_{t:03d}.png",
-            nrow=8,
-            padding=6,
-            pad_value=0,
-        )
+    for t in range(frame_sequences.size(1)):
+    # img = wireworld.seq2img(frame_sequences[:, t])
+    # torchvision.utils.save_image(
+    # img.float() / 255.0,
+    # f"/tmp/frame_{t:03d}.png",
+    # nrow=8,
+    # padding=6,
+    # pad_value=0,
+    # )
 
     # m = (torch.rand(seq.size()) < 0.05).long()
     # seq = (1 - m) * seq + m * 23
 
+    token_sequences = wireworld.generate_token_sequences(32)
+    wireworld.save_quizzes(token_sequences, "/tmp", "seq")
     # img = wireworld.seq2img(frame_sequences[:60])
 
     # torchvision.utils.save_image(