Update.
[picoclvr.git] / main.py
diff --git a/main.py b/main.py
index 3db87df..e723866 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -8,7 +8,7 @@
 # torch.backends.cuda.matmul.allow_tf23
 # torch.autocast(torch.bfloat16)
 
-import math, sys, argparse, time, tqdm, itertools, os
+import math, sys, argparse, time, tqdm, os
 
 import torch, torchvision
 from torch import nn
@@ -27,7 +27,8 @@ else:
 ######################################################################
 
 parser = argparse.ArgumentParser(
-    description="An implementation of GPT with cache to solve a toy geometric reasoning task."
+    description="An implementation of GPT with cache.",
+    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
 )
 
 parser.add_argument("--task", type=str, default="picoclvr")
@@ -40,7 +41,7 @@ parser.add_argument("--seed", type=int, default=0)
 
 parser.add_argument("--nb_epochs", type=int, default=25)
 
-parser.add_argument("--batch_size", type=int, default=25)
+parser.add_argument("--batch_size", type=int, default=None)
 
 parser.add_argument("--nb_train_samples", type=int, default=250000)
 
@@ -92,6 +93,17 @@ parser.add_argument("--maze_width", type=int, default=21)
 
 parser.add_argument("--maze_nb_walls", type=int, default=15)
 
+##############################
+# Snake options
+
+parser.add_argument("--snake_height", type=int, default=6)
+
+parser.add_argument("--snake_width", type=int, default=8)
+
+parser.add_argument("--snake_nb_colors", type=int, default=3)
+
+parser.add_argument("--snake_length", type=int, default=400)
+
 ######################################################################
 
 args = parser.parse_args()
@@ -117,6 +129,28 @@ if args.seed >= 0:
 
 ######################################################################
 
+default_args = {
+    "picoclvr": {
+        "batch_size": 25,
+    },
+    "mnist": {
+        "batch_size": 10,
+    },
+    "maze": {
+        "batch_size": 25,
+    },
+    "snake": {
+        "batch_size": 20,
+    },
+}
+
+if args.task in default_args:
+    for k, v in default_args[args.task].items():
+        if getattr(args, k) is None:
+            setattr(args, k, v)
+
+######################################################################
+
 
 def log_string(s):
     t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
@@ -488,7 +522,7 @@ class TaskMNIST(Task):
         masked_inplace_autoregression(
             model, self.batch_size, results, ar_mask, device=self.device
         )
-        image_name = os.path.join(args.result_dir, f"result_mnist_{n_epoch:04d}.png")
+        image_name = os.path.join(args.result_dir, f"mnist_result_{n_epoch:04d}.png")
         torchvision.utils.save_image(
             1 - results.reshape(-1, 1, 28, 28) / 255.0,
             image_name,
@@ -608,7 +642,7 @@ class TaskMaze(Task):
             mazes, paths = self.seq2map(input)
             _, predicted_paths = self.seq2map(result)
 
-            filename = os.path.join(args.result_dir, f"result_{n_epoch:04d}.png")
+            filename = os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.png")
             maze.save_image(
                 filename,
                 mazes=mazes,
@@ -623,6 +657,84 @@ class TaskMaze(Task):
 
 ######################################################################
 
+
+def generate_snake_sequences(
+    nb, height, width, nb_colors, length, device=torch.device("cpu")
+):
+    worlds = torch.randint(nb_colors, (nb, height, width), device=device)
+    nb_prior_visits = torch.zeros(nb, height, width, device=device)
+
+    # nb x 2
+    snake_position = torch.cat(
+        (
+            torch.randint(height, (nb, 1), device=device),
+            torch.randint(width, (nb, 1), device=device),
+        ),
+        1,
+    )
+    snake_direction = torch.randint(4, (nb,), device=device)
+    sequences = torch.empty(nb, 2 * length, device=device, dtype=torch.int64)
+    sequences_prior_visits = torch.zeros(
+        nb, 2 * length, device=device, dtype=torch.int64
+    )
+    i = torch.arange(nb, device=device)  # [:,None]
+
+    for l in range(length):
+        # nb x 3
+        snake_next_direction = torch.cat(
+            (
+                (snake_direction[:, None] - 1) % 4,
+                snake_direction[:, None],
+                (snake_direction[:, None] + 1) % 4,
+            ),
+            1,
+        )
+
+        # nb x 3
+        vh = (snake_next_direction + 1) % 2 * (snake_next_direction - 1)
+        vw = snake_next_direction % 2 * (snake_next_direction - 2)
+
+        # nb x 3 x 2
+        snake_next_speed = torch.cat((vh[:, :, None], vw[:, :, None]), 2)
+        snake_next_position = snake_position[:, None, :] + snake_next_speed
+
+        # nb x 3
+        val = torch.logical_and(
+            torch.logical_and(
+                snake_next_position[:, :, 0] >= 0, snake_next_position[:, :, 0] < height
+            ),
+            torch.logical_and(
+                snake_next_position[:, :, 1] >= 0, snake_next_position[:, :, 1] < width
+            ),
+        ).float()
+        val = (
+            # The multiplicative factors bias toward moving forward
+            torch.rand_like(val)
+            * val
+            * torch.tensor([[1.0, 2.0, 1.0]], device=device)
+        )
+
+        # nb
+        j = val.argmax(1)
+        snake_direction = snake_next_direction[i, j]
+
+        sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4
+        sequences_prior_visits[:, 2 * l] = nb_prior_visits[
+            i, snake_position[:, 0], snake_position[:, 1]
+        ]
+        nb_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1
+        sequences[:, 2 * l + 1] = snake_direction
+
+        # nb x 2
+        snake_position = snake_next_position[i, j]
+
+    return sequences, sequences_prior_visits
+
+
+# generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20)
+# exit(0)
+
+
 class TaskSnake(Task):
     def __init__(
         self,
@@ -631,7 +743,8 @@ class TaskSnake(Task):
         batch_size,
         height,
         width,
-        nb_walls,
+        nb_colors,
+        length,
         device=torch.device("cpu"),
     ):
         self.batch_size = batch_size
@@ -639,10 +752,14 @@ class TaskSnake(Task):
         self.width = width
         self.device = device
 
-        # self.train_input = 
-        # self.test_input = 
+        self.train_input, self.train_prior_visits = generate_snake_sequences(
+            nb_train_samples, height, width, nb_colors, length, self.device
+        )
+        self.test_input, self.test_prior_visits = generate_snake_sequences(
+            nb_test_samples, height, width, nb_colors, length, self.device
+        )
 
-        self.nb_codes = max(self.train_input.max(), self.train_input.max()) + 1
+        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
 
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -656,6 +773,54 @@ class TaskSnake(Task):
         ):
             yield batch
 
+    def vocabulary_size(self):
+        return self.nb_codes
+
+    def produce_results(self, n_epoch, model):
+        with torch.autograd.no_grad():
+            t = model.training
+            model.eval()
+
+            def compute_nb_correct(input, prior_visits):
+                result = input.clone()
+                i = torch.arange(result.size(1), device=result.device)[None, :]
+                ar_mask = torch.logical_and(i >= i.size(0) // 2, i % 2 == 0).long()
+                result *= 1 - ar_mask
+                masked_inplace_autoregression(
+                    model, self.batch_size, result, ar_mask, device=self.device
+                )
+
+                nb_total = (
+                    (prior_visits > 0) * ar_mask
+                ).sum()
+
+                nb_correct = (
+                    (result == input).long() * (prior_visits > 0) * ar_mask
+                ).sum()
+
+                # nb_total = result.size(0)
+                # nb_correct = ((result - input).abs().sum(1) == 0).sum()
+
+                return nb_total, nb_correct
+
+            train_nb_total, train_nb_correct = compute_nb_correct(
+                self.train_input, self.train_prior_visits
+            )
+
+            log_string(
+                f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
+            )
+
+            test_nb_total, test_nb_correct = compute_nb_correct(
+                self.test_input, self.test_prior_visits
+            )
+
+            log_string(
+                f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+            )
+
+            model.train(t)
+
 
 ######################################################################
 
@@ -708,6 +873,18 @@ elif args.task == "maze":
         device=device,
     )
 
+elif args.task == "snake":
+    task = TaskSnake(
+        nb_train_samples=args.nb_train_samples,
+        nb_test_samples=args.nb_test_samples,
+        batch_size=args.batch_size,
+        height=args.snake_height,
+        width=args.snake_width,
+        nb_colors=args.snake_nb_colors,
+        length=args.snake_length,
+        device=device,
+    )
+
 else:
     raise ValueError(f"Unknown task {args.task}")