Update: added the one-shot trace loss and ground-truth score maps
diff --git a/beaver.py b/beaver.py
index b505156..f62c749 100755
--- a/beaver.py
+++ b/beaver.py
@@ -173,13 +173,26 @@ def compute_perplexity(model, split="train"):
 ######################################################################
 
 
-def oneshot_policy_loss(output, policies, mask):
-    targets = policies.permute(0, 2, 1) * mask.unsqueeze(-1)
-    output = output * mask.unsqueeze(-1)
-    return -(output.log_softmax(-1) * targets).sum() / mask.sum()
+def oneshot_policy_loss(mazes, output, policies, height, width):
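+    # masked cross-entropy between the head's action logits and the target
+    # policy, averaged over the empty cells only; height and width are unused
+    # here but keep the signature interchangeable with oneshot_trace_loss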
+    masks = (mazes == maze.v_empty).unsqueeze(-1)
+    targets = policies.permute(0, 2, 1) * masks
+    output = output * masks
+    return -(output.log_softmax(-1) * targets).sum() / masks.sum()
 
 
-# loss = (output.softmax(-1) - targets).abs().max(-1).values.mean()
+def oneshot_trace_loss(mazes, output, policies, height, width):
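+    # L1 distance between the model's predicted density and the stationary
+    # density of the target policy, averaged over the empty cells only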
+    masks = mazes == maze.v_empty
+    targets = maze.stationary_densities(
+        mazes.view(-1, height, width), policies.view(-1, 4, height, width)
+    ).flatten(-2)
+    targets = targets * masks
+    output = output.squeeze(-1) * masks
+    return (output - targets).abs().sum() / masks.sum()
 
 
 def oneshot(gpt, task):
@@ -198,6 +211,7 @@ def oneshot(gpt, task):
         compute_loss = oneshot_policy_loss
     elif args.oneshot_output == "trace":
         dim_out = 1
+        compute_loss = oneshot_trace_loss
     else:
         raise ValueError(f"{args.oneshot_output=}")
 
@@ -206,7 +220,8 @@
         nn.ReLU(),
         nn.Linear(args.dim_model, args.dim_model),
         nn.ReLU(),
-        nn.Linear(args.dim_model, 4),
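+        # one logit per action in "policy" mode, a single density in "trace" mode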
+        nn.Linear(args.dim_model, dim_out),
     ).to(device)
 
     for n_epoch in range(args.nb_epochs):
@@ -214,54 +229,70 @@
         optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
 
         acc_train_loss, nb_train_samples = 0, 0
-        for input, policies in task.policy_batches(split="train"):
+        for mazes, policies in task.policy_batches(split="train"):
             ####
-            # print(f'{input.size()=} {policies.size()=}')
+            # print(f'{mazes.size()=} {policies.size()=}')
             # s = maze.stationary_densities(
             # exit(0)
             ####
-            mask = input == maze.v_empty
-            output_gpt = gpt(mygpt.BracketedSequence(input), mode=args.oneshot_input).x
+            output_gpt = gpt(mygpt.BracketedSequence(mazes), mode=args.oneshot_input).x
             output = model(output_gpt)
 
-            loss = compute_loss(output, policies, mask)
-            acc_train_loss += loss.item() * input.size(0)
-            nb_train_samples += input.size(0)
+            loss = compute_loss(mazes, output, policies, task.height, task.width)
+            acc_train_loss += loss.item() * mazes.size(0)
+            nb_train_samples += mazes.size(0)
 
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
 
         acc_test_loss, nb_test_samples = 0, 0
-        for input, policies in task.policy_batches(split="test"):
-            mask = input == maze.v_empty
-            output_gpt = gpt(mygpt.BracketedSequence(input), mode=args.oneshot_input).x
+        for mazes, policies in task.policy_batches(split="test"):
+            output_gpt = gpt(mygpt.BracketedSequence(mazes), mode=args.oneshot_input).x
             output = model(output_gpt)
-            loss = compute_loss(output, policies, mask)
-            acc_test_loss += loss.item() * input.size(0)
-            nb_test_samples += input.size(0)
+            loss = compute_loss(mazes, output, policies, task.height, task.width)
+            acc_test_loss += loss.item() * mazes.size(0)
+            nb_test_samples += mazes.size(0)
 
         log_string(
             f"diff_ce {n_epoch} train {acc_train_loss/nb_train_samples} test {acc_test_loss/nb_test_samples}"
         )
 
         # -------------------
-        input = task.test_input[:32, : task.height * task.width]
-        targets = task.test_policies[:32].permute(0, 2, 1)
-        output_gpt = gpt(mygpt.BracketedSequence(input), mode=args.oneshot_input).x
+        mazes = task.test_input[:32, : task.height * task.width]
+        policies = task.test_policies[:32]
+        output_gpt = gpt(mygpt.BracketedSequence(mazes), mode=args.oneshot_input).x
         output = model(output_gpt)
-        scores = (
-            (F.one_hot(output.argmax(-1), num_classes=4) * targets).sum(-1) == 0
-        ).float()
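+        # "policy" flags the cells where the argmax action has zero probability
+        # under the target policy; "trace" shows the predicted density itself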
+        if args.oneshot_output == "policy":
+            targets = policies.permute(0, 2, 1)
+            scores = (
+                (F.one_hot(output.argmax(-1), num_classes=4) * targets).sum(-1) == 0
+            ).float()
+            # keep one value per cell (the probability the target policy puts
+            # on its own best action) so that the reshape below works
+            targets = targets.max(-1).values
+        elif args.oneshot_output == "trace":
+            targets = maze.stationary_densities(
+                mazes.view(-1, task.height, task.width),
+                policies.view(-1, 4, task.height, task.width),
+            ).flatten(-2)
+            scores = output.flatten(-2)
+        else:
+            raise ValueError(f"{args.oneshot_output=}")
+
         scores = scores.reshape(-1, task.height, task.width)
-        input = input.reshape(-1, task.height, task.width)
+        mazes = mazes.reshape(-1, task.height, task.width)
+        targets = targets.reshape(-1, task.height, task.width)
         maze.save_image(
             os.path.join(
                 args.result_dir,
                 f"oneshot_{args.oneshot_input}_{args.oneshot_output}_{n_epoch:04d}.png",
             ),
-            mazes=input,
+            mazes=mazes,
             score_paths=scores,
+            score_truth=targets,
         )
         # -------------------
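
For reference, a self-contained sketch of what oneshot_policy_loss computes,
on toy tensors; the shapes are inferred from the diff, and the constant 0
stands in for maze.v_empty:

    import torch

    torch.manual_seed(0)
    N, HW = 2, 6                                     # two toy mazes, six cells each
    mazes = torch.randint(0, 3, (N, HW))             # 0 plays the role of maze.v_empty
    output = torch.randn(N, HW, 4)                   # head logits, one row per cell
    policies = torch.randn(N, 4, HW).softmax(dim=1)  # target action distributions

    masks = (mazes == 0).unsqueeze(-1)               # keep only the empty cells
    targets = policies.permute(0, 2, 1) * masks      # (N, HW, 4) masked targets
    output = output * masks
    loss = -(output.log_softmax(-1) * targets).sum() / masks.sum()
    print(loss.item())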