Update.
[picoclvr.git] / main.py
diff --git a/main.py b/main.py
index f8e451b..45bddb7 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -8,7 +8,7 @@
 # torch.backends.cuda.matmul.allow_tf23
 # torch.autocast(torch.bfloat16)
 
-import math, sys, argparse, time, tqdm, itertools, os
+import math, sys, argparse, time, tqdm, os
 
 import torch, torchvision
 from torch import nn
@@ -27,20 +27,23 @@ else:
 ######################################################################
 
 parser = argparse.ArgumentParser(
-    description="An implementation of GPT with cache to solve a toy geometric reasoning task."
+    description="An implementation of GPT with cache.",
+    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
 )
 
-parser.add_argument("--task", type=str, default="picoclvr")
+parser.add_argument(
+    "--task", type=str, default="picoclvr", help="picoclvr, mnist, maze, snake"
+)
 
-parser.add_argument("--log_filename", type=str, default="train.log")
+parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
 
 parser.add_argument("--result_dir", type=str, default="results_default")
 
 parser.add_argument("--seed", type=int, default=0)
 
-parser.add_argument("--nb_epochs", type=int, default=25)
+parser.add_argument("--nb_epochs", type=int, default=None)
 
-parser.add_argument("--batch_size", type=int, default=25)
+parser.add_argument("--batch_size", type=int, default=None)
 
 parser.add_argument("--nb_train_samples", type=int, default=250000)
 
@@ -92,6 +95,17 @@ parser.add_argument("--maze_width", type=int, default=21)
 
 parser.add_argument("--maze_nb_walls", type=int, default=15)
 
+##############################
+# Snake options
+
+parser.add_argument("--snake_height", type=int, default=6)
+
+parser.add_argument("--snake_width", type=int, default=8)
+
+parser.add_argument("--snake_nb_colors", type=int, default=5)
+
+parser.add_argument("--snake_length", type=int, default=200)
+
 ######################################################################
 
 args = parser.parse_args()
@@ -117,6 +131,32 @@ if args.seed >= 0:
 
 ######################################################################
 
+default_args = {
+    "picoclvr": {
+        "nb_epochs": 25,
+        "batch_size": 25,
+    },
+    "mnist": {
+        "nb_epochs": 25,
+        "batch_size": 10,
+    },
+    "maze": {
+        "nb_epochs": 25,
+        "batch_size": 25,
+    },
+    "snake": {
+        "nb_epochs": 5,
+        "batch_size": 25,
+    },
+}
+
+if args.task in default_args:
+    for k, v in default_args[args.task].items():
+        if getattr(args, k) is None:
+            setattr(args, k, v)
+
+######################################################################
+
 
 def log_string(s):
     t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
@@ -135,10 +175,29 @@ for n in vars(args):
 ######################################################################
 
 
+# ra_mask is boolean, with 1s on the values to generate
+
+
 def masked_inplace_autoregression(
-    model, batch_size, input, ar_mask, forbidden_tokens=None, device=torch.device("cpu")
+    model,
+    batch_size,
+    input,
+    ar_mask,
+    forbidden_tokens=None,
+    progress_bar_desc="autoregression",
+    device=torch.device("cpu"),
 ):
-    for input, ar_mask in zip(input.split(batch_size), ar_mask.split(batch_size)):
+    # p = logits.softmax(1)
+    # entropy[:,s]= p.xlogy(p).sum(1) / math.log(2)
+    batches = zip(input.split(batch_size), ar_mask.split(batch_size))
+    if progress_bar_desc is not None:
+        tqdm.tqdm(
+            batches,
+            dynamic_ncols=True,
+            desc=progress_bar_desc,
+            total=input.size(0) // batch_size,
+        )
+    for input, ar_mask in batches:
         i = (ar_mask.sum(0) > 0).nonzero()
         if i.min() > 0:
             model(
@@ -274,6 +333,7 @@ class TaskPicoCLVR(Task):
                 input,
                 ar_masks,
                 forbidden_tokens,
+                progress_bar_desc=None,
                 device=self.device,
             )
             model.train(t)
@@ -453,7 +513,7 @@ class TaskPicoCLVR(Task):
 
         image_name = os.path.join(args.result_dir, f"picoclvr_result_{n_epoch:04d}.png")
         torchvision.utils.save_image(
-            img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=1.0
+            img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
         )
         log_string(f"wrote {image_name}")
 
@@ -488,7 +548,7 @@ class TaskMNIST(Task):
         masked_inplace_autoregression(
             model, self.batch_size, results, ar_mask, device=self.device
         )
-        image_name = os.path.join(args.result_dir, f"result_mnist_{n_epoch:04d}.png")
+        image_name = os.path.join(args.result_dir, f"mnist_result_{n_epoch:04d}.png")
         torchvision.utils.save_image(
             1 - results.reshape(-1, 1, 28, 28) / 255.0,
             image_name,
@@ -563,39 +623,83 @@ class TaskMaze(Task):
 
     def compute_error(self, model, split="train", nb_to_use=-1):
         nb_total, nb_correct = 0, 0
-        for input in task.batches(split, nb_to_use):
+        count = torch.zeros(
+            self.width * self.height,
+            self.width * self.height,
+            device=self.device,
+            dtype=torch.int64,
+        )
+        for input in tqdm.tqdm(
+            task.batches(split, nb_to_use),
+            dynamic_ncols=True,
+            desc=f"test-mazes",
+        ):
             result = input.clone()
             ar_mask = result.new_zeros(result.size())
             ar_mask[:, self.height * self.width :] = 1
             result *= 1 - ar_mask
             masked_inplace_autoregression(
-                model, self.batch_size, result, ar_mask, device=self.device
+                model,
+                self.batch_size,
+                result,
+                ar_mask,
+                progress_bar_desc=None,
+                device=self.device,
             )
             mazes, paths = self.seq2map(result)
-            nb_correct += maze.path_correctness(mazes, paths).long().sum()
+            path_correctness = maze.path_correctness(mazes, paths)
+            nb_correct += path_correctness.long().sum()
             nb_total += mazes.size(0)
 
-        return nb_total, nb_correct
+            optimal_path_lengths = (
+                (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
+            )
+            predicted_path_lengths = (
+                (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
+            )
+            optimal_path_lengths = optimal_path_lengths[path_correctness]
+            predicted_path_lengths = predicted_path_lengths[path_correctness]
+            count[optimal_path_lengths, predicted_path_lengths] += 1
+
+        if count.max() == 0:
+            count = None
+        else:
+            count = count[
+                : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
+            ]
+
+        return nb_total, nb_correct, count
 
     def produce_results(self, n_epoch, model):
         with torch.autograd.no_grad():
             t = model.training
             model.eval()
 
-            train_nb_total, train_nb_correct = self.compute_error(
+            train_nb_total, train_nb_correct, count = self.compute_error(
                 model, "train", nb_to_use=1000
             )
             log_string(
                 f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
             )
 
-            test_nb_total, test_nb_correct = self.compute_error(
+            test_nb_total, test_nb_correct, count = self.compute_error(
                 model, "test", nb_to_use=1000
             )
             log_string(
                 f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
             )
 
+            if count is not None:
+                proportion_optimal = count.diagonal().sum().float() / count.sum()
+                log_string(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
+                with open(
+                    os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
+                ) as f:
+                    for i in range(count.size(0)):
+                        for j in range(count.size(1)):
+                            eol = " " if j < count.size(1) - 1 else "\n"
+                            f.write(f"{count[i,j]}{eol}")
+
             input = self.test_input[:48]
             result = input.clone()
             ar_mask = result.new_zeros(result.size())
@@ -608,13 +712,14 @@ class TaskMaze(Task):
             mazes, paths = self.seq2map(input)
             _, predicted_paths = self.seq2map(result)
 
-            filename = os.path.join(args.result_dir, f"result_{n_epoch:04d}.png")
+            filename = os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.png")
             maze.save_image(
                 filename,
                 mazes=mazes,
                 target_paths=paths,
                 predicted_paths=predicted_paths,
                 path_correct=maze.path_correctness(mazes, predicted_paths),
+                path_optimal=maze.path_optimality(paths, predicted_paths),
             )
             log_string(f"wrote {filename}")
 
@@ -624,69 +729,8 @@ class TaskMaze(Task):
 ######################################################################
 
 
-def generate_snake_sequences(
-    nb, height, width, nb_colors, length, device=torch.device("cpu")
-):
-    world = torch.randint(nb_colors, (nb, height, width), device=device)
-    # nb x 2
-    snake_position = torch.cat(
-        (
-            torch.randint(height, (nb, 1), device=device),
-            torch.randint(width, (nb, 1), device=device),
-        ),
-        1,
-    )
-    snake_direction = torch.randint(4, (nb, 1), device=device)
-    result = torch.empty(nb, 2*length, device=device, dtype=torch.int64)
-    count = torch.arange(nb, device=device)  # [:,None]
-
-    for l in range(length):
-        # nb x 3
-        snake_next_direction = torch.cat(
-            (
-                (snake_direction - 1) % 4,
-                snake_direction,
-                (snake_direction + 1) % 4,
-            ),
-            1,
-        )
-
-        # nb x 3
-        vh = (snake_next_direction + 1) % 2 * (snake_next_direction - 1)
-        vw = snake_next_direction % 2 * (snake_next_direction - 2)
-
-        # nb x 3 x 2
-        snake_next_speed = torch.cat((vh[:, :, None], vw[:, :, None]), 2)
-        snake_next_position = snake_position[:, None, :] + snake_next_speed
-
-        # nb x 3
-        val = torch.logical_and(
-            torch.logical_and(
-                snake_next_position[:, :, 0] >= 0, snake_next_position[:, :, 0] < height
-            ),
-            torch.logical_and(
-                snake_next_position[:, :, 1] >= 0, snake_next_position[:, :, 1] < width
-            ),
-        ).float()
-        val = torch.rand_like(val) * val * torch.tensor([[1.,4.,1.]], device=device)
+import snake
 
-        # nb
-        i = torch.arange(val.size(0), device=device)
-        j = val.argmax(1)
-
-        # nb x 1
-        snake_direction = snake_next_direction[i[:, None], j[:, None]]
-
-        result[:, 2*l] = world[count, snake_position[:, 0], snake_position[:, 1]]
-        result[:, 2*l+1] = snake_direction[:,0]
-
-        # nb x 2
-        snake_position = snake_next_position[i[:, None], j[:, None]].squeeze(1)
-
-    return result
-
-generate_snake_sequences(nb=2, height=4, width=5, nb_colors=3, length=10)
-exit(0)
 
 class TaskSnake(Task):
     def __init__(
@@ -698,18 +742,32 @@ class TaskSnake(Task):
         width,
         nb_colors,
         length,
+        prompt_length,
         device=torch.device("cpu"),
     ):
         self.batch_size = batch_size
         self.height = height
         self.width = width
         self.device = device
+        self.prompt_length = prompt_length
 
-        self.train_input = generate_snake_sequences(
-            nb_train_samples, height, width, nb_colors, length, self.device
+        self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
+            nb_train_samples,
+            height,
+            width,
+            nb_colors,
+            length,
+            prompt_length,
+            self.device,
         )
-        self.test_input = generate_snake_sequences(
-            nb_test_samples, height, width, nb_colors, length, self.device
+        self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
+            nb_test_samples,
+            height,
+            width,
+            nb_colors,
+            length,
+            prompt_length,
+            self.device,
         )
 
         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
@@ -729,6 +787,56 @@ class TaskSnake(Task):
     def vocabulary_size(self):
         return self.nb_codes
 
+    def produce_results(self, n_epoch, model):
+        with torch.autograd.no_grad():
+            t = model.training
+            model.eval()
+
+            def compute_nb_correct(input, prior_visits):
+                result = input.clone()
+                i = torch.arange(result.size(1), device=result.device)[None, :]
+                ar_mask = (
+                    torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
+                    .long()
+                    .expand_as(result)
+                )
+                result *= 1 - ar_mask
+
+                # snake.solver(result,ar_mask)
+
+                masked_inplace_autoregression(
+                    model, self.batch_size, result, ar_mask, device=self.device
+                )
+
+                nb_total = ((prior_visits > 0) * ar_mask).sum()
+
+                nb_correct = (
+                    (result == input).long() * (prior_visits > 0) * ar_mask
+                ).sum()
+
+                # nb_total = result.size(0)
+                # nb_correct = ((result - input).abs().sum(1) == 0).sum()
+
+                return nb_total, nb_correct
+
+            # train_nb_total, train_nb_correct = compute_nb_correct(
+            # self.train_input, self.train_prior_visits
+            # )
+
+            # log_string(
+            # f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
+            # )
+
+            test_nb_total, test_nb_correct = compute_nb_correct(
+                self.test_input[:1000], self.test_prior_visits[:1000]
+            )
+
+            log_string(
+                f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+            )
+
+            model.train(t)
+
 
 ######################################################################
 
@@ -786,10 +894,11 @@ elif args.task == "snake":
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
         batch_size=args.batch_size,
-        height=6,
-        width=8,
-        nb_colors=5,
-        length=100,
+        height=args.snake_height,
+        width=args.snake_width,
+        nb_colors=args.snake_nb_colors,
+        length=args.snake_length,
+        prompt_length=args.snake_length // 2,
         device=device,
     )
 
@@ -928,9 +1037,6 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
         for input in task.batches(split="test"):
             input = input.to(device)
 
-            # input, loss_masks, true_images = task.excise_last_image(input)
-            # input, loss_masks = task.add_true_image(input, true_images, loss_masks)
-
             output = model(mygpt.BracketedSequence(input)).x
             loss = F.cross_entropy(output.transpose(1, 2), input)
             acc_test_loss += loss.item() * input.size(0)