Oups
[picoclvr.git] / snake.py
index eb46a07..8a16f9f 100755 (executable)
--- a/snake.py
+++ b/snake.py
@@ -13,7 +13,7 @@ def generate_sequences(
     nb, height, width, nb_colors, length, prompt_length, device=torch.device("cpu")
 ):
     worlds = torch.randint(nb_colors, (nb, height, width), device=device)
-    nb_prior_visits = torch.zeros(nb, height, width, device=device)
+    world_prior_visits = torch.zeros(nb, height, width, device=device)
 
     # nb x 2
     snake_position = torch.cat(
@@ -70,17 +70,17 @@ def generate_sequences(
         snake_direction = snake_next_direction[i, j]
 
         sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4
-        sequences_prior_visits[:, 2 * l] = nb_prior_visits[
+        sequences_prior_visits[:, 2 * l] = world_prior_visits[
             i, snake_position[:, 0], snake_position[:, 1]
         ]
         if l < prompt_length:
-            nb_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1
+            world_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1
         sequences[:, 2 * l + 1] = snake_direction
 
         # nb x 2
         snake_position = snake_next_position[i, j]
 
-    return sequences, sequences_prior_visits
+    return sequences, sequences_prior_visits, worlds, world_prior_visits
 
 
 # generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20)
@@ -111,35 +111,22 @@ def solver(input, ar_mask):
             # print(f'@2 {i=} {j=}')
 
 
-######################################################################
-
-if __name__ == "__main__":
-    for n in range(16):
-        descr = generate(nb=1, height=12, width=16)
-
-        print(nb_properties(descr, height=12, width=16))
-
-        with open(f"picoclvr_example_{n:02d}.txt", "w") as f:
-            for d in descr:
-                f.write(f"{d}\n\n")
+def seq2str(seq):
+    return "".join(["NESW123456789"[i] for i in seq])
 
-        img = descr2img(descr, height=12, width=16)
-        if img.size(0) == 1:
-            img = F.pad(img, (1, 1, 1, 1), value=64)
 
-        torchvision.utils.save_image(
-            img / 255.0,
-            f"picoclvr_example_{n:02d}.png",
-            padding=1,
-            nrow=4,
-            pad_value=0.8,
-        )
+######################################################################
 
-    import time
+if __name__ == "__main__":
+    train_input, train_prior_visits, _, _ = generate_sequences(
+        nb=20,
+        height=9,
+        width=12,
+        nb_colors=5,
+        length=50,
+        prompt_length=100,
+    )
 
-    start_time = time.perf_counter()
-    descr = generate(nb=1000, height=12, width=16)
-    end_time = time.perf_counter()
-    print(f"{len(descr) / (end_time - start_time):.02f} samples per second")
+    print([seq2str(s) for s in train_input])
 
 ######################################################################