Update
authorFrançois Fleuret <francois@fleuret.org>
Mon, 20 Mar 2023 20:52:20 +0000 (21:52 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 20 Mar 2023 20:52:20 +0000 (21:52 +0100)
beaver.py
maze.py
tensorstack.py

index b505156..f62c749 100755 (executable)
--- a/beaver.py
+++ b/beaver.py
@@ -173,13 +173,21 @@ def compute_perplexity(model, split="train"):
 ######################################################################
 
 
-def oneshot_policy_loss(output, policies, mask):
-    targets = policies.permute(0, 2, 1) * mask.unsqueeze(-1)
-    output = output * mask.unsqueeze(-1)
-    return -(output.log_softmax(-1) * targets).sum() / mask.sum()
+def oneshot_policy_loss(mazes, output, policies, height, width):
+    masks = (mazes == maze.v_empty).unsqueeze(-1)
+    targets = policies.permute(0, 2, 1) * masks
+    output = output * masks
+    return -(output.log_softmax(-1) * targets).sum() / masks.sum()
 
 
-# loss = (output.softmax(-1) - targets).abs().max(-1).values.mean()
+def oneshot_trace_loss(mazes, output, policies, height, width):
+    masks = mazes == maze.v_empty
+    targets = maze.stationary_densities(
+        mazes.view(-1, height, width), policies.view(-1, 4, height, width)
+    ).flatten(-2)
+    targets = targets * masks
+    output = output.squeeze(-1) * masks
+    return (output - targets).abs().sum() / masks.sum()
 
 
 def oneshot(gpt, task):
@@ -198,6 +206,7 @@ def oneshot(gpt, task):
         compute_loss = oneshot_policy_loss
     elif args.oneshot_output == "trace":
         dim_out = 1
+        compute_loss = oneshot_trace_loss
     else:
         raise ValueError(f"{args.oneshot_output=}")
 
@@ -206,7 +215,7 @@ def oneshot(gpt, task):
         nn.ReLU(),
         nn.Linear(args.dim_model, args.dim_model),
         nn.ReLU(),
-        nn.Linear(args.dim_model, 4),
+        nn.Linear(args.dim_model, dim_out),
     ).to(device)
 
     for n_epoch in range(args.nb_epochs):
@@ -214,54 +223,66 @@ def oneshot(gpt, task):
         optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
 
         acc_train_loss, nb_train_samples = 0, 0
-        for input, policies in task.policy_batches(split="train"):
+        for mazes, policies in task.policy_batches(split="train"):
             ####
-            # print(f'{input.size()=} {policies.size()=}')
+            # print(f'{mazes.size()=} {policies.size()=}')
             # s = maze.stationary_densities(
             # exit(0)
             ####
-            mask = input == maze.v_empty
-            output_gpt = gpt(mygpt.BracketedSequence(input), mode=args.oneshot_input).x
+            masks = mazes == maze.v_empty
+            output_gpt = gpt(mygpt.BracketedSequence(mazes), mode=args.oneshot_input).x
             output = model(output_gpt)
 
-            loss = compute_loss(output, policies, mask)
-            acc_train_loss += loss.item() * input.size(0)
-            nb_train_samples += input.size(0)
+            loss = compute_loss(mazes, output, policies, task.height, task.width)
+            acc_train_loss += loss.item() * mazes.size(0)
+            nb_train_samples += mazes.size(0)
 
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
 
         acc_test_loss, nb_test_samples = 0, 0
-        for input, policies in task.policy_batches(split="test"):
-            mask = input == maze.v_empty
-            output_gpt = gpt(mygpt.BracketedSequence(input), mode=args.oneshot_input).x
+        for mazes, policies in task.policy_batches(split="test"):
+            output_gpt = gpt(mygpt.BracketedSequence(mazes), mode=args.oneshot_input).x
             output = model(output_gpt)
-            loss = compute_loss(output, policies, mask)
-            acc_test_loss += loss.item() * input.size(0)
-            nb_test_samples += input.size(0)
+            loss = compute_loss(mazes, output, policies, task.height, task.width)
+            acc_test_loss += loss.item() * mazes.size(0)
+            nb_test_samples += mazes.size(0)
 
         log_string(
             f"diff_ce {n_epoch} train {acc_train_loss/nb_train_samples} test {acc_test_loss/nb_test_samples}"
         )
 
         # -------------------
-        input = task.test_input[:32, : task.height * task.width]
-        targets = task.test_policies[:32].permute(0, 2, 1)
-        output_gpt = gpt(mygpt.BracketedSequence(input), mode=args.oneshot_input).x
+        mazes = task.test_input[:32, : task.height * task.width]
+        policies = task.test_policies[:32]
+        output_gpt = gpt(mygpt.BracketedSequence(mazes), mode=args.oneshot_input).x
         output = model(output_gpt)
-        scores = (
-            (F.one_hot(output.argmax(-1), num_classes=4) * targets).sum(-1) == 0
-        ).float()
+        if args.oneshot_output == "policy":
+            targets = policies.permute(0, 2, 1)
+            scores = (
+                (F.one_hot(output.argmax(-1), num_classes=4) * targets).sum(-1) == 0
+            ).float()
+        elif args.oneshot_output == "trace":
+            targets = maze.stationary_densities(
+                mazes.view(-1, task.height, task.width),
+                policies.view(-1, 4, task.height, task.width),
+            ).flatten(-2)
+            scores = output.flatten(-2)
+        else:
+            raise ValueError(f"{args.oneshot_output=}")
+
         scores = scores.reshape(-1, task.height, task.width)
-        input = input.reshape(-1, task.height, task.width)
+        mazes = mazes.reshape(-1, task.height, task.width)
+        targets = targets.reshape(-1, task.height, task.width)
         maze.save_image(
             os.path.join(
                 args.result_dir,
                 f"oneshot_{args.oneshot_input}_{args.oneshot_output}_{n_epoch:04d}.png",
             ),
-            mazes=input,
+            mazes=mazes,
             score_paths=scores,
+            score_truth=targets,
         )
         # -------------------
 
diff --git a/maze.py b/maze.py
index d09e860..81afcd9 100755 (executable)
--- a/maze.py
+++ b/maze.py
@@ -98,10 +98,10 @@ def compute_policy(walls, goal_i, goal_j):
     distance = distance + walls.numel() * walls
 
     value = distance.new_full((4,) + distance.size(), walls.numel())
-    value[0, :, 1:] = distance[:, :-1]
-    value[1, :, :-1] = distance[:, 1:]
-    value[2, 1:, :] = distance[:-1, :]
-    value[3, :-1, :] = distance[1:, :]
+    value[0, :, 1:] = distance[:, :-1]  # <
+    value[1, :, :-1] = distance[:, 1:]  # >
+    value[2, 1:, :] = distance[:-1, :]  # ^
+    value[3, :-1, :] = distance[1:, :]  # v
 
     proba = (value.min(dim=0)[0][None] == value).float()
     proba = proba / proba.sum(dim=0)[None]
@@ -111,18 +111,19 @@ def compute_policy(walls, goal_i, goal_j):
 
 
 def stationary_densities(mazes, policies):
+    policies = policies * (mazes != v_goal)[:, None]
     start = (mazes == v_start).nonzero(as_tuple=True)
-    probas = mazes.new_zeros(mazes.size())
+    probas = mazes.new_zeros(mazes.size(), dtype=torch.float32)
     pred_probas = probas.clone()
     probas[start] = 1.0
 
     while not pred_probas.equal(probas):
         pred_probas.copy_(probas)
         probas.zero_()
-        probas[:, 1:, :] = pred_probas[:, :-1, :] * policies[:, 0, :-1, :]
-        probas[:, :-1, :] = pred_probas[:, 1:, :] * policies[:, 1, 1:, :]
-        probas[:, :, 1:] = pred_probas[:, :, :-1] * policies[:, 2, :, :-1]
-        probas[:, :, :-1] = pred_probas[:, :, 1:] * policies[:, 3, :, 1:]
+        probas[:, 1:, :] += pred_probas[:, :-1, :] * policies[:, 3, :-1, :]
+        probas[:, :-1, :] += pred_probas[:, 1:, :] * policies[:, 2, 1:, :]
+        probas[:, :, 1:] += pred_probas[:, :, :-1] * policies[:, 1, :, :-1]
+        probas[:, :, :-1] += pred_probas[:, :, 1:] * policies[:, 0, :, 1:]
         probas[start] = 1.0
 
     return probas
@@ -211,6 +212,7 @@ def save_image(
     target_paths=None,
     predicted_paths=None,
     score_paths=None,
+    score_truth=None,
     path_correct=None,
 ):
     colors = torch.tensor(
@@ -229,6 +231,17 @@ def save_image(
         colors[mazes.reshape(-1)].reshape(mazes.size() + (-1,)).permute(0, 3, 1, 2)
     )
 
+    if score_truth is not None:
+        score_truth = score_truth.cpu()
+        c_score_truth = score_truth.unsqueeze(1).expand(-1, 3, -1, -1)
+        c_score_truth = (
+            c_score_truth * colors[4].reshape(1, 3, 1, 1)
+            + (1 - c_score_truth) * colors[0].reshape(1, 3, 1, 1)
+        ).long()
+        c_mazes = (mazes.unsqueeze(1) != v_empty) * c_mazes + (
+            mazes.unsqueeze(1) == v_empty
+        ) * c_score_truth
+
     imgs = c_mazes.unsqueeze(1)
 
     if target_paths is not None:
index 584c12d..074588e 100755 (executable)
@@ -11,9 +11,9 @@ import sys
 
 
 def exception_hook(exc_type, exc_value, tb):
-    r"""Hacks the call stack message to show all the local variables in
-    case of RuntimeError or ValueError, and prints tensors as shape,
-    dtype and device.
+    r"""Hacks the call stack message to show all the local variables
+    in case of relevant error, and prints tensors as shape, dtype and
+    device.
 
     """
 
@@ -28,7 +28,7 @@ def exception_hook(exc_type, exc_value, tb):
         print(f'  File "{filename}", line {line_no}, in {name}')
         print(open(filename, "r").readlines()[line_no - 1])
 
-        if exc_type in {RuntimeError, ValueError}:
+        if exc_type in {RuntimeError, ValueError, IndexError}:
             for n, v in tb.tb_frame.f_locals.items():
                 print(f"  {n} -> {v}")