Update.
diff --git a/main.py b/main.py
index acecfdd..43d2900 100755
--- a/main.py
+++ b/main.py
@@ -8,7 +8,7 @@
 # torch.backends.cuda.matmul.allow_tf32
 # torch.autocast(torch.bfloat16)
 
-import math, sys, argparse, time, tqdm, itertools, os
+import math, sys, argparse, time, tqdm, os
 
 import torch, torchvision
 from torch import nn
@@ -27,7 +27,8 @@ else:
 ######################################################################
 
 parser = argparse.ArgumentParser(
-    description="An implementation of GPT with cache to solve a toy geometric reasoning task."
+    description="An implementation of GPT with cache.",
+    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
 )
 
 parser.add_argument("--task", type=str, default="picoclvr")
@@ -40,7 +41,7 @@ parser.add_argument("--seed", type=int, default=0)
 
 parser.add_argument("--nb_epochs", type=int, default=25)
 
-parser.add_argument("--batch_size", type=int, default=25)
+parser.add_argument("--batch_size", type=int, default=None)
 
 parser.add_argument("--nb_train_samples", type=int, default=250000)
 
@@ -128,6 +129,28 @@ if args.seed >= 0:
 
 ######################################################################
 
+default_args = {
+    "picoclvr": {
+        "batch_size": 25,
+    },
+    "mnist": {
+        "batch_size": 10,
+    },
+    "maze": {
+        "batch_size": 25,
+    },
+    "snake": {
+        "batch_size": 20,
+    },
+}
+
+if args.task in default_args:
+    for k, v in default_args[args.task].items():
+        if getattr(args, k) is None:
+            setattr(args, k, v)
+
+######################################################################
+
 
 def log_string(s):
     t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
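The hunk above changes --batch_size to default to None and resolves it per task after parsing. A minimal, self-contained sketch of the same pattern (names mirror the diff; the argument list passed to parse_args is only for illustration):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="picoclvr")
    parser.add_argument("--batch_size", type=int, default=None)
    args = parser.parse_args(["--task", "snake"])  # no --batch_size given

    # Per-task defaults, applied only where the user left the flag unset
    default_args = {"picoclvr": {"batch_size": 25}, "snake": {"batch_size": 20}}
    for k, v in default_args.get(args.task, {}).items():
        if getattr(args, k) is None:
            setattr(args, k, v)

    assert args.batch_size == 20

Keeping None as the argparse default means an explicit command-line value always wins over the task-specific one.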
@@ -149,7 +172,12 @@ for n in vars(args):
 def masked_inplace_autoregression(
     model, batch_size, input, ar_mask, forbidden_tokens=None, device=torch.device("cpu")
 ):
-    for input, ar_mask in zip(input.split(batch_size), ar_mask.split(batch_size)):
+    for input, ar_mask in tqdm.tqdm(
+        zip(input.split(batch_size), ar_mask.split(batch_size)),
+        dynamic_ncols=True,
+        desc="autoregression",
+        total=(input.size(0) + batch_size - 1) // batch_size,
+    ):
         i = (ar_mask.sum(0) > 0).nonzero()
         if i.min() > 0:
             model(
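For reference, a standalone sketch of the batching loop above, assuming only torch and tqdm; the ceiling division keeps the progress-bar total exact when the last batch is short:

    import torch, tqdm

    x = torch.arange(23).view(-1, 1)  # 23 samples: the last batch is ragged
    batch_size = 5
    nb_batches = (x.size(0) + batch_size - 1) // batch_size  # 5, not 23 // 5 == 4

    for batch in tqdm.tqdm(x.split(batch_size), total=nb_batches, desc="batches"):
        pass  # each batch has at most batch_size rows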
@@ -636,9 +664,11 @@ class TaskMaze(Task):
 
 
 def generate_snake_sequences(
-    nb, height, width, nb_colors, length, device=torch.device("cpu")
+    nb, height, width, nb_colors, length, prompt_length, device=torch.device("cpu")
 ):
     worlds = torch.randint(nb_colors, (nb, height, width), device=device)
+    nb_prior_visits = torch.zeros(nb, height, width, device=device)
+
     # nb x 2
     snake_position = torch.cat(
         (
@@ -649,6 +679,9 @@ def generate_snake_sequences(
     )
     snake_direction = torch.randint(4, (nb,), device=device)
     sequences = torch.empty(nb, 2 * length, device=device, dtype=torch.int64)
+    sequences_prior_visits = torch.zeros(
+        nb, 2 * length, device=device, dtype=torch.int64
+    )
     i = torch.arange(nb, device=device)  # [:,None]
 
     for l in range(length):
@@ -680,7 +713,10 @@ def generate_snake_sequences(
             ),
         ).float()
         val = (
-            torch.rand_like(val) * val * torch.tensor([[1.0, 4.0, 1.0]], device=device)
+            # The multiplicative factors bias toward moving forward
+            torch.rand_like(val)
+            * val
+            * torch.tensor([[1.0, 2.0, 1.0]], device=device)
         )
 
         # nb
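The draw above multiplies uniform noise by a validity mask and by the [1.0, 2.0, 1.0] weights before the argmax that follows, so the forward move is favored over turning, though not with exactly twice the probability. A toy sketch of the scheme, with a hypothetical validity pattern:

    import torch

    valid = torch.tensor([[1.0, 1.0, 0.0]])    # e.g. the right turn is blocked
    weights = torch.tensor([[1.0, 2.0, 1.0]])  # bias toward moving forward
    scores = torch.rand_like(valid) * valid * weights
    j = scores.argmax(dim=1)  # a masked move scores 0 and loses to any valid one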
@@ -688,18 +724,47 @@ def generate_snake_sequences(
         snake_direction = snake_next_direction[i, j]
 
         sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4
+        sequences_prior_visits[:, 2 * l] = nb_prior_visits[
+            i, snake_position[:, 0], snake_position[:, 1]
+        ]
+        if l < prompt_length:
+            nb_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1
         sequences[:, 2 * l + 1] = snake_direction
 
         # nb x 2
         snake_position = snake_next_position[i, j]
 
-    return sequences, worlds
+    return sequences, sequences_prior_visits
 
 
 # generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20, prompt_length=10)
 # exit(0)
 
 
+def snake_solver(input, ar_mask):
+    # Reference solver: replay each sequence, memorize the color token seen
+    # at every visited cell, and fill in the masked color positions from
+    # that memory (-1 when the cell has not been visited before).
+    for n in range(input.size(0)):
+        i, j, memory = 0, 0, {}
+        for l in range(input.size(1) // 2):
+            if ar_mask[n, 2 * l] == 1:
+                if memory.get((i, j)) is None:
+                    input[n, 2 * l] = -1
+                else:
+                    input[n, 2 * l] = memory[(i, j)]
+            else:
+                if memory.get((i, j)) is None:
+                    memory[(i, j)] = input[n, 2 * l]
+                else:
+                    assert memory[(i, j)] == input[n, 2 * l], f"n={n} l={l}"
+            # Decode direction token d into a (row, column) step
+            d = input[n, 2 * l + 1].item()
+            i += (d + 1) % 2 * (d - 1)
+            j += d % 2 * (d - 2)
+
+
 class TaskSnake(Task):
     def __init__(
         self,
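The position update at the end of snake_solver packs the four moves into two arithmetic expressions; enumerating the direction token makes the encoding explicit (a direct check of the solver's own formulas):

    # Direction token d in {0, 1, 2, 3} -> (row step, column step)
    for d in range(4):
        di = (d + 1) % 2 * (d - 1)  # -1, 0, +1, 0
        dj = d % 2 * (d - 2)        #  0, -1, 0, +1
        print(d, (di, dj))
    # 0 -> (-1, 0)   1 -> (0, -1)   2 -> (+1, 0)   3 -> (0, +1)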
@@ -710,18 +775,32 @@ class TaskSnake(Task):
         width,
         nb_colors,
         length,
+        prompt_length,
         device=torch.device("cpu"),
     ):
         self.batch_size = batch_size
         self.height = height
         self.width = width
         self.device = device
+        self.prompt_length = prompt_length
 
-        self.train_input, self.train_worlds = generate_snake_sequences(
-            nb_train_samples, height, width, nb_colors, length, self.device
+        self.train_input, self.train_prior_visits = generate_snake_sequences(
+            nb_train_samples,
+            height,
+            width,
+            nb_colors,
+            length,
+            prompt_length,
+            self.device,
         )
-        self.test_input, self.test_worlds = generate_snake_sequences(
-            nb_test_samples, height, width, nb_colors, length, self.device
+        self.test_input, self.test_prior_visits = generate_snake_sequences(
+            nb_test_samples,
+            height,
+            width,
+            nb_colors,
+            length,
+            prompt_length,
+            self.device,
         )
 
         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
@@ -746,32 +825,44 @@ class TaskSnake(Task):
             t = model.training
             model.eval()
 
-            def compute_nb_correct(input):
+            def compute_nb_correct(input, prior_visits):
                 result = input.clone()
-                i = torch.arange(result.size(1), device=result.device)
-                ar_mask = torch.logical_and(i >= i.size(0) // 2, i % 2 == 0)[
-                    None, :
-                ].long()
+                i = torch.arange(result.size(1), device=result.device)[None, :]
+                ar_mask = (
+                    torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
+                    .long()
+                    .expand_as(result)
+                )
                 result *= 1 - ar_mask
+
+                # snake_solver(result, ar_mask)
+
                 masked_inplace_autoregression(
                     model, self.batch_size, result, ar_mask, device=self.device
                 )
 
-                nb_total = ar_mask.sum() * input.size(0)
-                nb_correct = ((result == input).long() * ar_mask).sum()
+                nb_total = ((prior_visits > 0) * ar_mask).sum()
+
+                nb_correct = (
+                    (result == input).long() * (prior_visits > 0) * ar_mask
+                ).sum()
 
                 # nb_total = result.size(0)
                 # nb_correct = ((result - input).abs().sum(1) == 0).sum()
 
                 return nb_total, nb_correct
 
-            train_nb_total, train_nb_correct = compute_nb_correct(self.train_input)
+            # train_nb_total, train_nb_correct = compute_nb_correct(
+            #     self.train_input, self.train_prior_visits
+            # )
 
-            log_string(
-                f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
-            )
+            # log_string(
+            #     f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
+            # )
 
-            test_nb_total, test_nb_correct = compute_nb_correct(self.test_input)
+            test_nb_total, test_nb_correct = compute_nb_correct(
+                self.test_input[:1000], self.test_prior_visits[:1000]
+            )
 
             log_string(
                 f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
@@ -840,6 +931,7 @@ elif args.task == "snake":
         width=args.snake_width,
         nb_colors=args.snake_nb_colors,
         length=args.snake_length,
+        prompt_length=args.snake_length // 2,
         device=device,
     )