Update.
[picoclvr.git] / main.py
diff --git a/main.py b/main.py
index 43d2900..14b1bc3 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -31,15 +31,17 @@ parser = argparse.ArgumentParser(
     formatter_class=argparse.ArgumentDefaultsHelpFormatter,
 )
 
-parser.add_argument("--task", type=str, default="picoclvr")
+parser.add_argument(
+    "--task", type=str, default="picoclvr", help="picoclvr, mnist, maze, snake, stack"
+)
 
-parser.add_argument("--log_filename", type=str, default="train.log")
+parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
 
 parser.add_argument("--result_dir", type=str, default="results_default")
 
 parser.add_argument("--seed", type=int, default=0)
 
-parser.add_argument("--nb_epochs", type=int, default=25)
+parser.add_argument("--nb_epochs", type=int, default=None)
 
 parser.add_argument("--batch_size", type=int, default=None)
 
@@ -100,9 +102,18 @@ parser.add_argument("--snake_height", type=int, default=6)
 
 parser.add_argument("--snake_width", type=int, default=8)
 
-parser.add_argument("--snake_nb_colors", type=int, default=3)
+parser.add_argument("--snake_nb_colors", type=int, default=5)
+
+parser.add_argument("--snake_length", type=int, default=200)
+
+##############################
+# Snake options
+
+parser.add_argument("--stack_nb_steps", type=int, default=100)
 
-parser.add_argument("--snake_length", type=int, default=400)
+parser.add_argument("--stack_nb_stacks", type=int, default=1)
+
+parser.add_argument("--stack_nb_values", type=int, default=10)
 
 ######################################################################
 
@@ -131,16 +142,34 @@ if args.seed >= 0:
 
 default_args = {
     "picoclvr": {
+        "nb_epochs": 25,
         "batch_size": 25,
+        "nb_train_samples": 250000,
+        "nb_test_samples": 10000,
     },
     "mnist": {
+        "nb_epochs": 25,
         "batch_size": 10,
+        "nb_train_samples": 250000,
+        "nb_test_samples": 10000,
     },
     "maze": {
+        "nb_epochs": 25,
         "batch_size": 25,
+        "nb_train_samples": 250000,
+        "nb_test_samples": 10000,
     },
     "snake": {
-        "batch_size": 20,
+        "nb_epochs": 5,
+        "batch_size": 25,
+        "nb_train_samples": 250000,
+        "nb_test_samples": 10000,
+    },
+    "stack": {
+        "nb_epochs": 5,
+        "batch_size": 25,
+        "nb_train_samples": 100000,
+        "nb_test_samples": 1000,
     },
 }
 
@@ -169,15 +198,29 @@ for n in vars(args):
 ######################################################################
 
 
+# ra_mask is boolean, with 1s on the values to generate
+
+
 def masked_inplace_autoregression(
-    model, batch_size, input, ar_mask, forbidden_tokens=None, device=torch.device("cpu")
+    model,
+    batch_size,
+    input,
+    ar_mask,
+    forbidden_tokens=None,
+    progress_bar_desc="autoregression",
+    device=torch.device("cpu"),
 ):
-    for input, ar_mask in tqdm.tqdm(
-        zip(input.split(batch_size), ar_mask.split(batch_size)),
-        dynamic_ncols=True,
-        desc="autoregression",
-        total=input.size(0) // batch_size,
-    ):
+    # p = logits.softmax(1)
+    # entropy[:,s]= p.xlogy(p).sum(1) / math.log(2)
+    batches = zip(input.split(batch_size), ar_mask.split(batch_size))
+    if progress_bar_desc is not None:
+        tqdm.tqdm(
+            batches,
+            dynamic_ncols=True,
+            desc=progress_bar_desc,
+            total=input.size(0) // batch_size,
+        )
+    for input, ar_mask in batches:
         i = (ar_mask.sum(0) > 0).nonzero()
         if i.min() > 0:
             model(
@@ -313,6 +356,7 @@ class TaskPicoCLVR(Task):
                 input,
                 ar_masks,
                 forbidden_tokens,
+                progress_bar_desc=None,
                 device=self.device,
             )
             model.train(t)
@@ -492,7 +536,7 @@ class TaskPicoCLVR(Task):
 
         image_name = os.path.join(args.result_dir, f"picoclvr_result_{n_epoch:04d}.png")
         torchvision.utils.save_image(
-            img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=1.0
+            img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
         )
         log_string(f"wrote {image_name}")
 
@@ -602,39 +646,83 @@ class TaskMaze(Task):
 
     def compute_error(self, model, split="train", nb_to_use=-1):
         nb_total, nb_correct = 0, 0
-        for input in task.batches(split, nb_to_use):
+        count = torch.zeros(
+            self.width * self.height,
+            self.width * self.height,
+            device=self.device,
+            dtype=torch.int64,
+        )
+        for input in tqdm.tqdm(
+            task.batches(split, nb_to_use),
+            dynamic_ncols=True,
+            desc=f"test-mazes",
+        ):
             result = input.clone()
             ar_mask = result.new_zeros(result.size())
             ar_mask[:, self.height * self.width :] = 1
             result *= 1 - ar_mask
             masked_inplace_autoregression(
-                model, self.batch_size, result, ar_mask, device=self.device
+                model,
+                self.batch_size,
+                result,
+                ar_mask,
+                progress_bar_desc=None,
+                device=self.device,
             )
             mazes, paths = self.seq2map(result)
-            nb_correct += maze.path_correctness(mazes, paths).long().sum()
+            path_correctness = maze.path_correctness(mazes, paths)
+            nb_correct += path_correctness.long().sum()
             nb_total += mazes.size(0)
 
-        return nb_total, nb_correct
+            optimal_path_lengths = (
+                (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
+            )
+            predicted_path_lengths = (
+                (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
+            )
+            optimal_path_lengths = optimal_path_lengths[path_correctness]
+            predicted_path_lengths = predicted_path_lengths[path_correctness]
+            count[optimal_path_lengths, predicted_path_lengths] += 1
+
+        if count.max() == 0:
+            count = None
+        else:
+            count = count[
+                : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
+            ]
+
+        return nb_total, nb_correct, count
 
     def produce_results(self, n_epoch, model):
         with torch.autograd.no_grad():
             t = model.training
             model.eval()
 
-            train_nb_total, train_nb_correct = self.compute_error(
+            train_nb_total, train_nb_correct, count = self.compute_error(
                 model, "train", nb_to_use=1000
             )
             log_string(
                 f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
             )
 
-            test_nb_total, test_nb_correct = self.compute_error(
+            test_nb_total, test_nb_correct, count = self.compute_error(
                 model, "test", nb_to_use=1000
             )
             log_string(
                 f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
             )
 
+            if count is not None:
+                proportion_optimal = count.diagonal().sum().float() / count.sum()
+                log_string(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
+                with open(
+                    os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
+                ) as f:
+                    for i in range(count.size(0)):
+                        for j in range(count.size(1)):
+                            eol = " " if j < count.size(1) - 1 else "\n"
+                            f.write(f"{count[i,j]}{eol}")
+
             input = self.test_input[:48]
             result = input.clone()
             ar_mask = result.new_zeros(result.size())
@@ -654,6 +742,7 @@ class TaskMaze(Task):
                 target_paths=paths,
                 predicted_paths=predicted_paths,
                 path_correct=maze.path_correctness(mazes, predicted_paths),
+                path_optimal=maze.path_optimality(paths, predicted_paths),
             )
             log_string(f"wrote {filename}")
 
@@ -663,106 +752,7 @@ class TaskMaze(Task):
 ######################################################################
 
 
-def generate_snake_sequences(
-    nb, height, width, nb_colors, length, prompt_length, device=torch.device("cpu")
-):
-    worlds = torch.randint(nb_colors, (nb, height, width), device=device)
-    nb_prior_visits = torch.zeros(nb, height, width, device=device)
-
-    # nb x 2
-    snake_position = torch.cat(
-        (
-            torch.randint(height, (nb, 1), device=device),
-            torch.randint(width, (nb, 1), device=device),
-        ),
-        1,
-    )
-    snake_direction = torch.randint(4, (nb,), device=device)
-    sequences = torch.empty(nb, 2 * length, device=device, dtype=torch.int64)
-    sequences_prior_visits = torch.zeros(
-        nb, 2 * length, device=device, dtype=torch.int64
-    )
-    i = torch.arange(nb, device=device)  # [:,None]
-
-    for l in range(length):
-        # nb x 3
-        snake_next_direction = torch.cat(
-            (
-                (snake_direction[:, None] - 1) % 4,
-                snake_direction[:, None],
-                (snake_direction[:, None] + 1) % 4,
-            ),
-            1,
-        )
-
-        # nb x 3
-        vh = (snake_next_direction + 1) % 2 * (snake_next_direction - 1)
-        vw = snake_next_direction % 2 * (snake_next_direction - 2)
-
-        # nb x 3 x 2
-        snake_next_speed = torch.cat((vh[:, :, None], vw[:, :, None]), 2)
-        snake_next_position = snake_position[:, None, :] + snake_next_speed
-
-        # nb x 3
-        val = torch.logical_and(
-            torch.logical_and(
-                snake_next_position[:, :, 0] >= 0, snake_next_position[:, :, 0] < height
-            ),
-            torch.logical_and(
-                snake_next_position[:, :, 1] >= 0, snake_next_position[:, :, 1] < width
-            ),
-        ).float()
-        val = (
-            # The multiplicative factors bias toward moving forward
-            torch.rand_like(val)
-            * val
-            * torch.tensor([[1.0, 2.0, 1.0]], device=device)
-        )
-
-        # nb
-        j = val.argmax(1)
-        snake_direction = snake_next_direction[i, j]
-
-        sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4
-        sequences_prior_visits[:, 2 * l] = nb_prior_visits[
-            i, snake_position[:, 0], snake_position[:, 1]
-        ]
-        if l < prompt_length:
-            nb_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1
-        sequences[:, 2 * l + 1] = snake_direction
-
-        # nb x 2
-        snake_position = snake_next_position[i, j]
-
-    return sequences, sequences_prior_visits
-
-
-# generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20)
-# exit(0)
-
-
-def snake_solver(input, ar_mask):
-    for n in range(input.size(0)):
-        i, j, memory = 0, 0, {}
-        # print(input[n])
-        # print(ar_mask[n])
-        for l in range(input.size(1) // 2):
-            if ar_mask[n, 2 * l] == 1:
-                if memory.get((i, j)) is None:
-                    input[n, 2 * l] = -1
-                else:
-                    input[n, 2 * l] = memory[(i, j)]
-            else:
-                # print(f'@3 {memory=}')
-                if memory.get((i, j)) is None:
-                    memory[(i, j)] = input[n, 2 * l]
-                else:
-                    assert memory[(i, j)] == input[n, 2 * l], f"n={n} l={l}"
-            # print(f'@1 {i=} {j=}')
-            d = input[n, 2 * l + 1].item()
-            i += (d + 1) % 2 * (d - 1)
-            j += d % 2 * (d - 2)
-            # print(f'@2 {i=} {j=}')
+import snake
 
 
 class TaskSnake(Task):
@@ -784,7 +774,7 @@ class TaskSnake(Task):
         self.device = device
         self.prompt_length = prompt_length
 
-        self.train_input, self.train_prior_visits = generate_snake_sequences(
+        self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
             nb_train_samples,
             height,
             width,
@@ -793,7 +783,7 @@ class TaskSnake(Task):
             prompt_length,
             self.device,
         )
-        self.test_input, self.test_prior_visits = generate_snake_sequences(
+        self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
             nb_test_samples,
             height,
             width,
@@ -835,7 +825,7 @@ class TaskSnake(Task):
                 )
                 result *= 1 - ar_mask
 
-                # snake_solver(result,ar_mask)
+                # snake.solver(result,ar_mask)
 
                 masked_inplace_autoregression(
                     model, self.batch_size, result, ar_mask, device=self.device
@@ -874,6 +864,93 @@ class TaskSnake(Task):
 ######################################################################
 
 
+import stack
+
+
+class TaskStack(Task):
+    def __init__(
+        self,
+        nb_train_samples,
+        nb_test_samples,
+        batch_size,
+        nb_steps,
+        nb_stacks,
+        nb_values,
+        device=torch.device("cpu"),
+    ):
+        self.batch_size = batch_size
+        self.nb_steps = nb_steps
+        self.nb_stacks = nb_stacks
+        self.nb_values = nb_values
+        self.device = device
+
+        self.train_input, self.train_stack_counts = stack.generate_sequences(
+            nb_train_samples, nb_steps, nb_stacks, nb_values, self.device
+        )
+
+        self.test_input, self.test_stack_counts = stack.generate_sequences(
+            nb_test_samples, nb_steps, nb_stacks, nb_values, self.device
+        )
+
+        mask = self.test_input.clone()
+        stack.remove_poped_values(mask,self.nb_stacks)
+        mask=(mask!=self.test_input)
+        counts = self.test_stack_counts.flatten()[mask.flatten()]
+        counts=F.one_hot(counts).sum(0)
+        log_string(f"stack_count {counts}")
+
+        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+
+    def batches(self, split="train", nb_to_use=-1, desc=None):
+        assert split in {"train", "test"}
+        input = self.train_input if split == "train" else self.test_input
+        if nb_to_use > 0:
+            input = input[:nb_to_use]
+        if desc is None:
+            desc = f"epoch-{split}"
+        for batch in tqdm.tqdm(
+            input.split(self.batch_size), dynamic_ncols=True, desc=desc
+        ):
+            yield batch
+
+    def vocabulary_size(self):
+        return self.nb_codes
+
+    def produce_results(self, n_epoch, model):
+        with torch.autograd.no_grad():
+            t = model.training
+            model.eval()
+
+            def compute_nb_correct(input):
+                result = input.clone()
+                stack.remove_poped_values(result,self.nb_stacks)
+                ar_mask = (result != input).long()
+                result *= 1 - ar_mask
+
+                masked_inplace_autoregression(
+                    model, self.batch_size, result, ar_mask, device=self.device
+                )
+
+                nb_total = ar_mask.sum()
+
+                nb_correct = (
+                    (result == input).long() * ar_mask
+                ).sum()
+
+                return nb_total, nb_correct
+
+            test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])
+
+            log_string(
+                f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+            )
+
+            model.train(t)
+
+
+######################################################################
+
+
 def picoclvr_pruner_horizontal_green(p):
     return not ("green" in p and ("left" in p or "right" in p))
 
@@ -935,6 +1012,17 @@ elif args.task == "snake":
         device=device,
     )
 
+elif args.task == "stack":
+    task = TaskStack(
+        nb_train_samples=args.nb_train_samples,
+        nb_test_samples=args.nb_test_samples,
+        batch_size=args.batch_size,
+        nb_steps = args.stack_nb_steps,
+        nb_stacks = args.stack_nb_stacks,
+        nb_values = args.stack_nb_values,
+        device=device,
+    )
+
 else:
     raise ValueError(f"Unknown task {args.task}")
 
@@ -1070,9 +1158,6 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
         for input in task.batches(split="test"):
             input = input.to(device)
 
-            # input, loss_masks, true_images = task.excise_last_image(input)
-            # input, loss_masks = task.add_true_image(input, true_images, loss_masks)
-
             output = model(mygpt.BracketedSequence(input)).x
             loss = F.cross_entropy(output.transpose(1, 2), input)
             acc_test_loss += loss.item() * input.size(0)