Update
[beaver.git] / beaver.py
index 6ed9dd2..9f8bc41 100755 (executable)
--- a/beaver.py
+++ b/beaver.py
@@ -238,6 +238,11 @@ def oneshot(model, learning_rate_scheduler, task):
     model.eval()
     mazes = task.test_input[:32].clone()
     mazes[:, task.height * task.width :] = 0
+    policies = task.test_policies[:32]
+    targets = maze.stationary_densities(
+        mazes[:, : task.height * task.width].view(-1, task.height, task.width),
+        policies.view(-1, 4, task.height, task.width),
+    ).flatten(-2)
     output = eval_mygpt(model, mazes, prompt_len=task.height * task.width)
     output = F.softmax(output, dim=2)
     print(f"{output.size()=}")
@@ -245,13 +250,17 @@ def oneshot(model, learning_rate_scheduler, task):
         -1, task.height, task.width
     )
     mazes = mazes[:, : task.height * task.width].reshape(-1, task.height, task.width)
-    # targets = targets.reshape(-1, task.height, task.width)
+    targets = targets.reshape(-1, task.height, task.width)
+    paths = task.test_input[:32, task.height * task.width :].reshape(
+        -1, task.height, task.width
+    )
     filename = f"oneshot.png"
     maze.save_image(
         os.path.join(args.result_dir, filename),
         mazes=mazes,
+        # target_paths=paths,
         score_paths=proba_path,
-        score_truth=targets,
+        score_truth=targets,
     )
     log_string(f"wrote {filename}")