Update
[beaver.git] / beaver.py
index 6ed9dd2..e69f151 100755 (executable)
--- a/beaver.py
+++ b/beaver.py
@@ -127,6 +127,8 @@ def log_string(s):
     sys.stdout.flush()
 
 
+log_string(f"cmd {' '.join(sys.argv)}")
+
 for n in vars(args):
     log_string(f"args.{n} {getattr(args, n)}")
 
@@ -238,6 +240,11 @@ def oneshot(model, learning_rate_scheduler, task):
     model.eval()
     mazes = task.test_input[:32].clone()
     mazes[:, task.height * task.width :] = 0
+    policies = task.test_policies[:32]
+    targets = maze.stationary_densities(
+        mazes[:, : task.height * task.width].view(-1, task.height, task.width),
+        policies.view(-1, 4, task.height, task.width),
+    ).flatten(-2)
     output = eval_mygpt(model, mazes, prompt_len=task.height * task.width)
     output = F.softmax(output, dim=2)
     print(f"{output.size()=}")
@@ -245,13 +252,17 @@ def oneshot(model, learning_rate_scheduler, task):
         -1, task.height, task.width
     )
     mazes = mazes[:, : task.height * task.width].reshape(-1, task.height, task.width)
-    # targets = targets.reshape(-1, task.height, task.width)
+    targets = targets.reshape(-1, task.height, task.width)
+    paths = task.test_input[:32, task.height * task.width :].reshape(
+        -1, task.height, task.width
+    )
     filename = f"oneshot.png"
     maze.save_image(
         os.path.join(args.result_dir, filename),
         mazes=mazes,
+        # target_paths=paths,
         score_paths=proba_path,
-        score_truth=targets,
+        score_truth=targets,
     )
     log_string(f"wrote {filename}")