Update.
[picoclvr.git] / main.py
diff --git a/main.py b/main.py
index 24c21f5..acecfdd 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -92,6 +92,17 @@ parser.add_argument("--maze_width", type=int, default=21)
 
 parser.add_argument("--maze_nb_walls", type=int, default=15)
 
+##############################
+# Snake options
+
+parser.add_argument("--snake_height", type=int, default=6)
+
+parser.add_argument("--snake_width", type=int, default=8)
+
+parser.add_argument("--snake_nb_colors", type=int, default=3)
+
+parser.add_argument("--snake_length", type=int, default=400)
+
 ######################################################################
 
 args = parser.parse_args()
@@ -488,7 +499,7 @@ class TaskMNIST(Task):
         masked_inplace_autoregression(
             model, self.batch_size, results, ar_mask, device=self.device
         )
-        image_name = os.path.join(args.result_dir, f"result_mnist_{n_epoch:04d}.png")
+        image_name = os.path.join(args.result_dir, f"mnist_result_{n_epoch:04d}.png")
         torchvision.utils.save_image(
             1 - results.reshape(-1, 1, 28, 28) / 255.0,
             image_name,
@@ -526,7 +537,7 @@ class TaskMaze(Task):
         self.width = width
         self.device = device
 
-        train_mazes, train_paths, train_policies = maze.create_maze_data(
+        train_mazes, train_paths, _ = maze.create_maze_data(
             nb_train_samples,
             height=height,
             width=width,
@@ -534,9 +545,8 @@ class TaskMaze(Task):
             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
         )
         self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
-        self.train_policies = train_policies.flatten(-2).to(device)
 
-        test_mazes, test_paths, test_policies = maze.create_maze_data(
+        test_mazes, test_paths, _ = maze.create_maze_data(
             nb_test_samples,
             height=height,
             width=width,
@@ -544,9 +554,8 @@ class TaskMaze(Task):
             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
         )
         self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
-        self.test_policies = test_policies.flatten(-2).to(device)
 
-        self.nb_codes = self.train_input.max() + 1
+        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
 
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -560,26 +569,6 @@ class TaskMaze(Task):
         ):
             yield batch
 
-    def policy_batches(self, split="train", nb_to_use=-1, desc=None):
-        assert split in {"train", "test"}
-        input = self.train_input if split == "train" else self.test_input
-        policies = self.train_policies if split == "train" else self.test_policies
-        input = input[:, : self.height * self.width]
-        policies = policies * (input != maze.v_wall)[:, None]
-
-        if nb_to_use > 0:
-            input = input[:nb_to_use]
-            policies = policies[:nb_to_use]
-
-        if desc is None:
-            desc = f"epoch-{split}"
-        for batch in tqdm.tqdm(
-            zip(input.split(self.batch_size), policies.split(self.batch_size)),
-            dynamic_ncols=True,
-            desc=desc,
-        ):
-            yield batch
-
     def vocabulary_size(self):
         return self.nb_codes
 
@@ -630,7 +619,7 @@ class TaskMaze(Task):
             mazes, paths = self.seq2map(input)
             _, predicted_paths = self.seq2map(result)
 
-            filename = os.path.join(args.result_dir, f"result_{n_epoch:04d}.png")
+            filename = os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.png")
             maze.save_image(
                 filename,
                 mazes=mazes,
@@ -646,6 +635,154 @@ class TaskMaze(Task):
 ######################################################################
 
 
+def generate_snake_sequences(
+    nb, height, width, nb_colors, length, device=torch.device("cpu")
+):
+    worlds = torch.randint(nb_colors, (nb, height, width), device=device)
+    # nb x 2
+    snake_position = torch.cat(
+        (
+            torch.randint(height, (nb, 1), device=device),
+            torch.randint(width, (nb, 1), device=device),
+        ),
+        1,
+    )
+    snake_direction = torch.randint(4, (nb,), device=device)
+    sequences = torch.empty(nb, 2 * length, device=device, dtype=torch.int64)
+    i = torch.arange(nb, device=device)  # [:,None]
+
+    for l in range(length):
+        # nb x 3
+        snake_next_direction = torch.cat(
+            (
+                (snake_direction[:, None] - 1) % 4,
+                snake_direction[:, None],
+                (snake_direction[:, None] + 1) % 4,
+            ),
+            1,
+        )
+
+        # nb x 3
+        vh = (snake_next_direction + 1) % 2 * (snake_next_direction - 1)
+        vw = snake_next_direction % 2 * (snake_next_direction - 2)
+
+        # nb x 3 x 2
+        snake_next_speed = torch.cat((vh[:, :, None], vw[:, :, None]), 2)
+        snake_next_position = snake_position[:, None, :] + snake_next_speed
+
+        # nb x 3
+        val = torch.logical_and(
+            torch.logical_and(
+                snake_next_position[:, :, 0] >= 0, snake_next_position[:, :, 0] < height
+            ),
+            torch.logical_and(
+                snake_next_position[:, :, 1] >= 0, snake_next_position[:, :, 1] < width
+            ),
+        ).float()
+        val = (
+            torch.rand_like(val) * val * torch.tensor([[1.0, 4.0, 1.0]], device=device)
+        )
+
+        # nb
+        j = val.argmax(1)
+        snake_direction = snake_next_direction[i, j]
+
+        sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4
+        sequences[:, 2 * l + 1] = snake_direction
+
+        # nb x 2
+        snake_position = snake_next_position[i, j]
+
+    return sequences, worlds
+
+
+# generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20)
+# exit(0)
+
+
+class TaskSnake(Task):
+    def __init__(
+        self,
+        nb_train_samples,
+        nb_test_samples,
+        batch_size,
+        height,
+        width,
+        nb_colors,
+        length,
+        device=torch.device("cpu"),
+    ):
+        self.batch_size = batch_size
+        self.height = height
+        self.width = width
+        self.device = device
+
+        self.train_input, self.train_worlds = generate_snake_sequences(
+            nb_train_samples, height, width, nb_colors, length, self.device
+        )
+        self.test_input, self.test_worlds = generate_snake_sequences(
+            nb_test_samples, height, width, nb_colors, length, self.device
+        )
+
+        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+
+    def batches(self, split="train", nb_to_use=-1, desc=None):
+        assert split in {"train", "test"}
+        input = self.train_input if split == "train" else self.test_input
+        if nb_to_use > 0:
+            input = input[:nb_to_use]
+        if desc is None:
+            desc = f"epoch-{split}"
+        for batch in tqdm.tqdm(
+            input.split(self.batch_size), dynamic_ncols=True, desc=desc
+        ):
+            yield batch
+
+    def vocabulary_size(self):
+        return self.nb_codes
+
+    def produce_results(self, n_epoch, model):
+        with torch.autograd.no_grad():
+            t = model.training
+            model.eval()
+
+            def compute_nb_correct(input):
+                result = input.clone()
+                i = torch.arange(result.size(1), device=result.device)
+                ar_mask = torch.logical_and(i >= i.size(0) // 2, i % 2 == 0)[
+                    None, :
+                ].long()
+                result *= 1 - ar_mask
+                masked_inplace_autoregression(
+                    model, self.batch_size, result, ar_mask, device=self.device
+                )
+
+                nb_total = ar_mask.sum() * input.size(0)
+                nb_correct = ((result == input).long() * ar_mask).sum()
+
+                # nb_total = result.size(0)
+                # nb_correct = ((result - input).abs().sum(1) == 0).sum()
+
+                return nb_total, nb_correct
+
+            train_nb_total, train_nb_correct = compute_nb_correct(self.train_input)
+
+            log_string(
+                f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
+            )
+
+            test_nb_total, test_nb_correct = compute_nb_correct(self.test_input)
+
+            log_string(
+                f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+            )
+
+            model.train(t)
+
+
+######################################################################
+
+
 def picoclvr_pruner_horizontal_green(p):
     return not ("green" in p and ("left" in p or "right" in p))
 
@@ -694,6 +831,18 @@ elif args.task == "maze":
         device=device,
     )
 
+elif args.task == "snake":
+    task = TaskSnake(
+        nb_train_samples=args.nb_train_samples,
+        nb_test_samples=args.nb_test_samples,
+        batch_size=args.batch_size,
+        height=args.snake_height,
+        width=args.snake_width,
+        nb_colors=args.snake_nb_colors,
+        length=args.snake_length,
+        device=device,
+    )
+
 else:
     raise ValueError(f"Unknown task {args.task}")