Update
authorFrançois Fleuret <francois@fleuret.org>
Mon, 20 Mar 2023 21:42:19 +0000 (22:42 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 20 Mar 2023 21:42:19 +0000 (22:42 +0100)
beaver.py

index f62c749..a3a5615 100755 (executable)
--- a/beaver.py
+++ b/beaver.py
@@ -229,7 +229,6 @@ def oneshot(gpt, task):
             # s = maze.stationary_densities(
             # exit(0)
             ####
-            masks = mazes == maze.v_empty
             output_gpt = gpt(mygpt.BracketedSequence(mazes), mode=args.oneshot_input).x
             output = model(output_gpt)
 
@@ -268,7 +267,8 @@ def oneshot(gpt, task):
                 mazes.view(-1, task.height, task.width),
                 policies.view(-1, 4, task.height, task.width),
             ).flatten(-2)
-            scores = output.flatten(-2)
+            scores = output
+            print(scores)
         else:
             raise ValueError(f"{args.oneshot_output=}")