Update
[beaver.git] / beaver.py
index a289867..f62c749 100755 (executable)
--- a/beaver.py
+++ b/beaver.py
@@ -26,9 +26,7 @@ else:
 
 ######################################################################
 
-parser = argparse.ArgumentParser(
-    description="An implementation of GPT with cache to solve a toy geometric reasoning task."
-)
+parser = argparse.ArgumentParser(description="A maze shortest path solving with a GPT.")
 
 parser.add_argument("--log_filename", type=str, default="train.log")
 
@@ -81,6 +79,12 @@ parser.add_argument("--maze_width", type=int, default=21)
 
 parser.add_argument("--maze_nb_walls", type=int, default=15)
 
+parser.add_argument("--oneshot", action="store_true", default=False)
+
+parser.add_argument("--oneshot_input", type=str, default="head")
+
+parser.add_argument("--oneshot_output", type=str, default="policy")
+
 ######################################################################
 
 args = parser.parse_args()
@@ -127,13 +131,11 @@ for n in vars(args):
 
 
 def masked_inplace_autoregression(model, batch_size, input, ar_mask):
-
     for input, ar_mask in zip(input.split(batch_size), ar_mask.split(batch_size)):
         i = (ar_mask.sum(0) > 0).nonzero()
         if i.min() > 0:
-            model(
-                mygpt.BracketedSequence(input, 0, i.min())
-            )  # Needed to initialize the model's cache
+            # Needed to initialize the model's cache
+            model(mygpt.BracketedSequence(input, 0, i.min()))
         for s in range(i.min(), i.max() + 1):
             output = model(mygpt.BracketedSequence(input, s, 1)).x
             logits = output[:, s]
@@ -148,6 +150,148 @@ def masked_inplace_autoregression(model, batch_size, input, ar_mask):
 ######################################################################
 
 
+def compute_perplexity(model, split="train"):
+    with torch.autograd.no_grad():
+        t = model.training
+        model.eval()
+
+        nb_samples, acc_loss = 0, 0.0
+
+        for input in task.batches(split=split):
+            input = input.to(device)
+
+            output = model(mygpt.BracketedSequence(input)).x
+            loss = F.cross_entropy(output.transpose(1, 2), input)
+            acc_loss += loss.item() * input.size(0)
+            nb_samples += input.size(0)
+
+        model.train(t)
+
+        return math.exp(min(100, acc_loss / nb_samples))
+
+
+######################################################################
+
+
+def oneshot_policy_loss(mazes, output, policies, height, width):
+    masks = (mazes == maze.v_empty).unsqueeze(-1)
+    targets = policies.permute(0, 2, 1) * masks
+    output = output * masks
+    return -(output.log_softmax(-1) * targets).sum() / masks.sum()
+
+
+def oneshot_trace_loss(mazes, output, policies, height, width):
+    masks = mazes == maze.v_empty
+    targets = maze.stationary_densities(
+        mazes.view(-1, height, width), policies.view(-1, 4, height, width)
+    ).flatten(-2)
+    targets = targets * masks
+    output = output.squeeze(-1) * masks
+    return (output - targets).abs().sum() / masks.sum()
+
+
+def oneshot(gpt, task):
+    t = gpt.training
+    gpt.eval()
+
+    if args.oneshot_input == "head":
+        dim_in = args.dim_model
+    elif args.oneshot_input == "deep":
+        dim_in = args.dim_model * args.nb_blocks * 2
+    else:
+        raise ValueError(f"{args.oneshot_input=}")
+
+    if args.oneshot_output == "policy":
+        dim_out = 4
+        compute_loss = oneshot_policy_loss
+    elif args.oneshot_output == "trace":
+        dim_out = 1
+        compute_loss = oneshot_trace_loss
+    else:
+        raise ValueError(f"{args.oneshot_output=}")
+
+    model = nn.Sequential(
+        nn.Linear(dim_in, args.dim_model),
+        nn.ReLU(),
+        nn.Linear(args.dim_model, args.dim_model),
+        nn.ReLU(),
+        nn.Linear(args.dim_model, dim_out),
+    ).to(device)
+
+    for n_epoch in range(args.nb_epochs):
+        learning_rate = learning_rate_schedule[n_epoch]
+        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
+
+        acc_train_loss, nb_train_samples = 0, 0
+        for mazes, policies in task.policy_batches(split="train"):
+            ####
+            # print(f'{mazes.size()=} {policies.size()=}')
+            # s = maze.stationary_densities(
+            # exit(0)
+            ####
+            masks = mazes == maze.v_empty
+            output_gpt = gpt(mygpt.BracketedSequence(mazes), mode=args.oneshot_input).x
+            output = model(output_gpt)
+
+            loss = compute_loss(mazes, output, policies, task.height, task.width)
+            acc_train_loss += loss.item() * mazes.size(0)
+            nb_train_samples += mazes.size(0)
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+        acc_test_loss, nb_test_samples = 0, 0
+        for mazes, policies in task.policy_batches(split="test"):
+            output_gpt = gpt(mygpt.BracketedSequence(mazes), mode=args.oneshot_input).x
+            output = model(output_gpt)
+            loss = compute_loss(mazes, output, policies, task.height, task.width)
+            acc_test_loss += loss.item() * mazes.size(0)
+            nb_test_samples += mazes.size(0)
+
+        log_string(
+            f"diff_ce {n_epoch} train {acc_train_loss/nb_train_samples} test {acc_test_loss/nb_test_samples}"
+        )
+
+        # -------------------
+        mazes = task.test_input[:32, : task.height * task.width]
+        policies = task.test_policies[:32]
+        output_gpt = gpt(mygpt.BracketedSequence(mazes), mode=args.oneshot_input).x
+        output = model(output_gpt)
+        if args.oneshot_output == "policy":
+            targets = policies.permute(0, 2, 1)
+            scores = (
+                (F.one_hot(output.argmax(-1), num_classes=4) * targets).sum(-1) == 0
+            ).float()
+        elif args.oneshot_output == "trace":
+            targets = maze.stationary_densities(
+                mazes.view(-1, task.height, task.width),
+                policies.view(-1, 4, task.height, task.width),
+            ).flatten(-2)
+            scores = output.flatten(-2)
+        else:
+            raise ValueError(f"{args.oneshot_output=}")
+
+        scores = scores.reshape(-1, task.height, task.width)
+        mazes = mazes.reshape(-1, task.height, task.width)
+        targets = targets.reshape(-1, task.height, task.width)
+        maze.save_image(
+            os.path.join(
+                args.result_dir,
+                f"oneshot_{args.oneshot_input}_{args.oneshot_output}_{n_epoch:04d}.png",
+            ),
+            mazes=mazes,
+            score_paths=scores,
+            score_truth=targets,
+        )
+        # -------------------
+
+    gpt.train(t)
+
+
+######################################################################
+
+
 class Task:
     def batches(self, split="train"):
         pass
@@ -187,26 +331,27 @@ class TaskMaze(Task):
         self.width = width
         self.device = device
 
-        mazes_train, paths_train = maze.create_maze_data(
+        train_mazes, train_paths, train_policies = maze.create_maze_data(
             nb_train_samples,
             height=height,
             width=width,
             nb_walls=nb_walls,
             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
         )
-        mazes_train, paths_train = mazes_train.to(device), paths_train.to(device)
-        self.train_input = self.map2seq(mazes_train, paths_train)
-        self.nb_codes = self.train_input.max() + 1
+        self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
+        self.train_policies = train_policies.flatten(-2).to(device)
 
-        mazes_test, paths_test = maze.create_maze_data(
+        test_mazes, test_paths, test_policies = maze.create_maze_data(
             nb_test_samples,
             height=height,
             width=width,
             nb_walls=nb_walls,
             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
         )
-        mazes_test, paths_test = mazes_test.to(device), paths_test.to(device)
-        self.test_input = self.map2seq(mazes_test, paths_test)
+        self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
+        self.test_policies = test_policies.flatten(-2).to(device)
+
+        self.nb_codes = self.train_input.max() + 1
 
     def batches(self, split="train", nb_to_use=-1):
         assert split in {"train", "test"}
@@ -218,6 +363,24 @@ class TaskMaze(Task):
         ):
             yield batch
 
+    def policy_batches(self, split="train", nb_to_use=-1):
+        assert split in {"train", "test"}
+        input = self.train_input if split == "train" else self.test_input
+        policies = self.train_policies if split == "train" else self.test_policies
+        input = input[:, : self.height * self.width]
+        policies = policies * (input != maze.v_wall)[:, None]
+
+        if nb_to_use > 0:
+            input = input[:nb_to_use]
+            policies = policies[:nb_to_use]
+
+        for batch in tqdm.tqdm(
+            zip(input.split(self.batch_size), policies.split(self.batch_size)),
+            dynamic_ncols=True,
+            desc=f"epoch-{split}",
+        ):
+            yield batch
+
     def vocabulary_size(self):
         return self.nb_codes
 
@@ -227,6 +390,7 @@ class TaskMaze(Task):
             result = input.clone()
             ar_mask = result.new_zeros(result.size())
             ar_mask[:, self.height * self.width :] = 1
+            result *= 1 - ar_mask
             masked_inplace_autoregression(model, self.batch_size, result, ar_mask)
             mazes, paths = self.seq2map(result)
             nb_correct += maze.path_correctness(mazes, paths).long().sum()
@@ -256,18 +420,18 @@ class TaskMaze(Task):
             input = self.test_input[:32]
             result = input.clone()
             ar_mask = result.new_zeros(result.size())
-
             ar_mask[:, self.height * self.width :] = 1
+            result *= 1 - ar_mask
             masked_inplace_autoregression(model, self.batch_size, result, ar_mask)
 
             mazes, paths = self.seq2map(input)
             _, predicted_paths = self.seq2map(result)
             maze.save_image(
-                f"result_{n_epoch:04d}.png",
-                mazes,
-                paths,
-                predicted_paths,
-                maze.path_correctness(mazes, predicted_paths),
+                os.path.join(args.result_dir, f"result_{n_epoch:04d}.png"),
+                mazes=mazes,
+                target_paths=paths,
+                predicted_paths=predicted_paths,
+                path_correct=maze.path_correctness(mazes, predicted_paths),
             )
 
             model.train(t)
@@ -339,8 +503,6 @@ else:
 
 ######################################################################
 
-nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
-
 token_count = 0
 for input in task.batches(split="train"):
     token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
@@ -374,13 +536,28 @@ log_string(f"learning_rate_schedule {learning_rate_schedule}")
 
 ##############################
 
-nb_samples_seen = 0
+if args.oneshot:
+    oneshot(model, task)
+    exit(0)
 
-if nb_epochs_finished >= nb_epochs:
-    task.produce_results(nb_epochs_finished, model)
+##############################
+
+if nb_epochs_finished >= args.nb_epochs:
+    n_epoch = nb_epochs_finished
+    train_perplexity = compute_perplexity(model, split="train")
+    test_perplexity = compute_perplexity(model, split="test")
+
+    log_string(
+        f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
+    )
+
+    task.produce_results(n_epoch, model)
+
+    exit(0)
 
-for n_epoch in range(nb_epochs_finished, nb_epochs):
+##############################
 
+for n_epoch in range(nb_epochs_finished, args.nb_epochs):
     learning_rate = learning_rate_schedule[n_epoch]
 
     log_string(f"learning_rate {learning_rate}")
@@ -392,7 +569,7 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
     elif args.optim == "adamw":
         optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
     else:
-        raise ValueError(f"Unknown optimizer {args.optim}.")
+        raise ValueError(f"{args.optim=}")
 
     model.train()
 
@@ -404,37 +581,19 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
         loss = F.cross_entropy(output.transpose(1, 2), input)
         acc_train_loss += loss.item() * input.size(0)
         nb_train_samples += input.size(0)
-        nb_samples_seen += input.size(0)
 
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
 
-    with torch.autograd.no_grad():
-
-        model.eval()
-
-        nb_test_samples, acc_test_loss = 0, 0.0
+    train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+    test_perplexity = compute_perplexity(model, split="test")
 
-        for input in task.batches(split="test"):
-            input = input.to(device)
-
-            # input, loss_masks, true_images = task.excise_last_image(input)
-            # input, loss_masks = task.add_true_image(input, true_images, loss_masks)
-
-            output = model(mygpt.BracketedSequence(input)).x
-            loss = F.cross_entropy(output.transpose(1, 2), input)
-            acc_test_loss += loss.item() * input.size(0)
-            nb_test_samples += input.size(0)
-
-        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
-        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
-
-        log_string(
-            f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
-        )
+    log_string(
+        f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
+    )
 
-        task.produce_results(n_epoch, model)
+    task.produce_results(n_epoch, model)
 
     checkpoint = {
         "nb_epochs_finished": n_epoch + 1,