Update
[beaver.git] / beaver.py
index b505156..bdc12aa 100755 (executable)
--- a/beaver.py
+++ b/beaver.py
@@ -79,11 +79,14 @@ parser.add_argument("--maze_width", type=int, default=21)
 
 parser.add_argument("--maze_nb_walls", type=int, default=15)
 
+##############################
+# one-shot prediction
+
 parser.add_argument("--oneshot", action="store_true", default=False)
 
 parser.add_argument("--oneshot_input", type=str, default="head")
 
-parser.add_argument("--oneshot_output", type=str, default="policy")
+parser.add_argument("--oneshot_output", type=str, default="trace")
 
 ######################################################################
 
@@ -173,13 +176,21 @@ def compute_perplexity(model, split="train"):
 ######################################################################
 
 
-def oneshot_policy_loss(output, policies, mask):
-    targets = policies.permute(0, 2, 1) * mask.unsqueeze(-1)
-    output = output * mask.unsqueeze(-1)
-    return -(output.log_softmax(-1) * targets).sum() / mask.sum()
+def oneshot_policy_loss(mazes, output, policies, height, width):
+    masks = (mazes == maze.v_empty).unsqueeze(-1)
+    targets = policies.permute(0, 2, 1) * masks
+    output = output * masks
+    return -(output.log_softmax(-1) * targets).sum() / masks.sum()
 
 
-# loss = (output.softmax(-1) - targets).abs().max(-1).values.mean()
+def oneshot_trace_loss(mazes, output, policies, height, width):
+    masks = mazes == maze.v_empty
+    targets = maze.stationary_densities(
+        mazes.view(-1, height, width), policies.view(-1, 4, height, width)
+    ).flatten(-2)
+    targets = targets * masks
+    output = output.squeeze(-1) * masks
+    return (output - targets).abs().sum() / masks.sum()
 
 
 def oneshot(gpt, task):
@@ -198,6 +209,7 @@ def oneshot(gpt, task):
         compute_loss = oneshot_policy_loss
     elif args.oneshot_output == "trace":
         dim_out = 1
+        compute_loss = oneshot_trace_loss
     else:
         raise ValueError(f"{args.oneshot_output=}")
 
@@ -206,7 +218,7 @@ def oneshot(gpt, task):
         nn.ReLU(),
         nn.Linear(args.dim_model, args.dim_model),
         nn.ReLU(),
-        nn.Linear(args.dim_model, 4),
+        nn.Linear(args.dim_model, dim_out),
     ).to(device)
 
     for n_epoch in range(args.nb_epochs):
@@ -214,54 +226,60 @@ def oneshot(gpt, task):
         optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
 
         acc_train_loss, nb_train_samples = 0, 0
-        for input, policies in task.policy_batches(split="train"):
-            ####
-            # print(f'{input.size()=} {policies.size()=}')
-            # s = maze.stationary_densities(
-            # exit(0)
-            ####
-            mask = input == maze.v_empty
-            output_gpt = gpt(mygpt.BracketedSequence(input), mode=args.oneshot_input).x
+        for mazes, policies in task.policy_batches(split="train"):
+            output_gpt = gpt(mygpt.BracketedSequence(mazes), mode=args.oneshot_input).x
             output = model(output_gpt)
 
-            loss = compute_loss(output, policies, mask)
-            acc_train_loss += loss.item() * input.size(0)
-            nb_train_samples += input.size(0)
+            loss = compute_loss(mazes, output, policies, task.height, task.width)
+            acc_train_loss += loss.item() * mazes.size(0)
+            nb_train_samples += mazes.size(0)
 
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
 
         acc_test_loss, nb_test_samples = 0, 0
-        for input, policies in task.policy_batches(split="test"):
-            mask = input == maze.v_empty
-            output_gpt = gpt(mygpt.BracketedSequence(input), mode=args.oneshot_input).x
+        for mazes, policies in task.policy_batches(split="test"):
+            output_gpt = gpt(mygpt.BracketedSequence(mazes), mode=args.oneshot_input).x
             output = model(output_gpt)
-            loss = compute_loss(output, policies, mask)
-            acc_test_loss += loss.item() * input.size(0)
-            nb_test_samples += input.size(0)
+            loss = compute_loss(mazes, output, policies, task.height, task.width)
+            acc_test_loss += loss.item() * mazes.size(0)
+            nb_test_samples += mazes.size(0)
 
         log_string(
             f"diff_ce {n_epoch} train {acc_train_loss/nb_train_samples} test {acc_test_loss/nb_test_samples}"
         )
 
         # -------------------
-        input = task.test_input[:32, : task.height * task.width]
-        targets = task.test_policies[:32].permute(0, 2, 1)
-        output_gpt = gpt(mygpt.BracketedSequence(input), mode=args.oneshot_input).x
+        mazes = task.test_input[:32, : task.height * task.width]
+        policies = task.test_policies[:32]
+        output_gpt = gpt(mygpt.BracketedSequence(mazes), mode=args.oneshot_input).x
         output = model(output_gpt)
-        scores = (
-            (F.one_hot(output.argmax(-1), num_classes=4) * targets).sum(-1) == 0
-        ).float()
+        if args.oneshot_output == "policy":
+            targets = policies.permute(0, 2, 1)
+            scores = (
+                (F.one_hot(output.argmax(-1), num_classes=4) * targets).sum(-1) == 0
+            ).float()
+        elif args.oneshot_output == "trace":
+            targets = maze.stationary_densities(
+                mazes.view(-1, task.height, task.width),
+                policies.view(-1, 4, task.height, task.width),
+            ).flatten(-2)
+            scores = output
+        else:
+            raise ValueError(f"{args.oneshot_output=}")
+
         scores = scores.reshape(-1, task.height, task.width)
-        input = input.reshape(-1, task.height, task.width)
+        mazes = mazes.reshape(-1, task.height, task.width)
+        targets = targets.reshape(-1, task.height, task.width)
         maze.save_image(
             os.path.join(
                 args.result_dir,
                 f"oneshot_{args.oneshot_input}_{args.oneshot_output}_{n_epoch:04d}.png",
             ),
-            mazes=input,
+            mazes=mazes,
             score_paths=scores,
+            score_truth=targets,
         )
         # -------------------
 
@@ -515,12 +533,6 @@ log_string(f"learning_rate_schedule {learning_rate_schedule}")
 
 ##############################
 
-if args.oneshot:
-    oneshot(model, task)
-    exit(0)
-
-##############################
-
 if nb_epochs_finished >= args.nb_epochs:
     n_epoch = nb_epochs_finished
     train_perplexity = compute_perplexity(model, split="train")
@@ -588,3 +600,8 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
     log_string(f"saved checkpoint {checkpoint_name}")
 
 ######################################################################
+
+if args.oneshot:
+    oneshot(model, task)
+
+######################################################################