Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 29 Jun 2024 13:24:23 +0000 (16:24 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 29 Jun 2024 13:24:23 +0000 (16:24 +0300)
main.py
wireworld.py

diff --git a/main.py b/main.py
index 1565499..b88847e 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -13,7 +13,7 @@ from torch.nn import functional as F
 
 import ffutils
 import mygpt
-import sky, quizz_machine
+import sky, wireworld, quizz_machine
 
 # world quizzes vs. culture quizzes
 
@@ -37,7 +37,7 @@ parser = argparse.ArgumentParser(
     formatter_class=argparse.ArgumentDefaultsHelpFormatter,
 )
 
-parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
+parser.add_argument("--log_filename", type=str, default="train.log")
 
 parser.add_argument("--result_dir", type=str, default=None)
 
@@ -79,6 +79,8 @@ parser.add_argument("--dropout", type=float, default=0.1)
 
 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
 
+parser.add_argument("--problem", type=str, default="sky")
+
 parser.add_argument("--nb_gpts", type=int, default=5)
 
 parser.add_argument("--nb_models_for_generation", type=int, default=1)
@@ -219,8 +221,15 @@ else:
 assert args.nb_train_samples % args.batch_size == 0
 assert args.nb_test_samples % args.batch_size == 0
 
-quizz_machine = quizz_machine.QuizzMachine(
+if args.problem=="sky":
     problem=sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2, speed=2),
+elif args.problem="wireworld":
+    problem=wireworld.Wireworld(height=10, width=15, nb_iterations=4)
+else:
+    raise ValueError
+
+quizz_machine = quizz_machine.QuizzMachine(
+    problem=problem,
     nb_train_samples=args.nb_train_samples,
     nb_test_samples=args.nb_test_samples,
     batch_size=args.physical_batch_size,
index 98e2334..219d7dd 100755 (executable)
@@ -17,7 +17,7 @@ from torch.nn import functional as F
 import problem
 
 
-class Physics(problem.Problem):
+class Wireworld(problem.Problem):
     colors = torch.tensor(
         [
             [128, 128, 128],
@@ -38,14 +38,11 @@ class Physics(problem.Problem):
         "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
     )
 
-    def __init__(
-        self, height=6, width=8, nb_objects=2, nb_walls=2, speed=1, nb_iterations=4
-    ):
+    def __init__(self, height=6, width=8, nb_objects=2, nb_walls=2, nb_iterations=4):
         self.height = height
         self.width = width
         self.nb_objects = nb_objects
         self.nb_walls = nb_walls
-        self.speed = speed
         self.nb_iterations = nb_iterations
 
     def direction_tokens(self):
@@ -55,7 +52,7 @@ class Physics(problem.Problem):
         frame_sequences = []
 
         result = torch.full(
-            (nb * 100, self.nb_iterations, self.height, self.width), self.token_empty
+            (nb * 4, self.nb_iterations, self.height, self.width), self.token_empty
         )
 
         for n in range(result.size(0)):
@@ -114,7 +111,7 @@ class Physics(problem.Problem):
         result = result[i]
 
         if result.size(0) < nb:
-            print(result.size(0))
+            print(result.size(0))
             result = torch.cat(
                 [result, self.generate_frame_sequences(nb - result.size(0))], dim=0
             )
@@ -264,17 +261,17 @@ class Physics(problem.Problem):
 if __name__ == "__main__":
     import time
 
-    sky = Physics(height=10, width=15, speed=1, nb_iterations=100)
+    wireworld = Wireworld(height=10, width=15, nb_iterations=4)
 
     start_time = time.perf_counter()
-    frame_sequences = sky.generate_frame_sequences(nb=96)
+    frame_sequences = wireworld.generate_frame_sequences(nb=96)
     delay = time.perf_counter() - start_time
     print(f"{frame_sequences.size(0)/delay:02f} seq/s")
 
-    # print(sky.seq2str(seq[:4]))
+    # print(wireworld.seq2str(seq[:4]))
 
     for t in range(frame_sequences.size(1)):
-        img = sky.seq2img(frame_sequences[:, t])
+        img = wireworld.seq2img(frame_sequences[:, t])
         torchvision.utils.save_image(
             img.float() / 255.0,
             f"/tmp/frame_{t:03d}.png",
@@ -286,7 +283,7 @@ if __name__ == "__main__":
     # m = (torch.rand(seq.size()) < 0.05).long()
     # seq = (1 - m) * seq + m * 23
 
-    # img = sky.seq2img(frame_sequences[:60])
+    # img = wireworld.seq2img(frame_sequences[:60])
 
     # torchvision.utils.save_image(
     # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=10, pad_value=0.1