Update
author François Fleuret <francois@fleuret.org>
Fri, 7 Apr 2023 12:26:46 +0000 (14:26 +0200)
committer François Fleuret <francois@fleuret.org>
Fri, 7 Apr 2023 12:26:46 +0000 (14:26 +0200)
beaver.py

index 6e1eaf4..9f8bc41 100755 (executable)
--- a/beaver.py
+++ b/beaver.py
@@ -238,6 +238,11 @@ def oneshot(model, learning_rate_scheduler, task):
     model.eval()
     mazes = task.test_input[:32].clone()
     mazes[:, task.height * task.width :] = 0
+    policies = task.test_policies[:32]
+    targets = maze.stationary_densities(
+        mazes[:, : task.height * task.width].view(-1, task.height, task.width),
+        policies.view(-1, 4, task.height, task.width),
+    ).flatten(-2)
     output = eval_mygpt(model, mazes, prompt_len=task.height * task.width)
     output = F.softmax(output, dim=2)
     print(f"{output.size()=}")
@@ -245,6 +250,7 @@ def oneshot(model, learning_rate_scheduler, task):
         -1, task.height, task.width
     )
     mazes = mazes[:, : task.height * task.width].reshape(-1, task.height, task.width)
+    targets = targets.reshape(-1, task.height, task.width)
     paths = task.test_input[:32, task.height * task.width :].reshape(
         -1, task.height, task.width
     )
@@ -252,9 +258,9 @@ def oneshot(model, learning_rate_scheduler, task):
     maze.save_image(
         os.path.join(args.result_dir, filename),
         mazes=mazes,
-        target_paths=paths,
+        target_paths=paths,
         score_paths=proba_path,
-        score_truth=targets,
+        score_truth=targets,
     )
     log_string(f"wrote {filename}")