Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 24 Mar 2024 10:29:15 +0000 (11:29 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 24 Mar 2024 10:29:15 +0000 (11:29 +0100)
escape.py

index a8b9536..93e3052 100755 (executable)
--- a/escape.py
+++ b/escape.py
@@ -204,19 +204,34 @@ def episodes2str(
 
         def status_bar(a, r, lr=None):
             a, r = a.item(), r.item()
-            sb = "ISNEW"[a] if a >= 0 and a < 5 else "?"
-            sb = sb + thin_vert + ("- +"[r + 1] if r in {-1, 0, 1} else "?")
+            sb_a = "ISNEW"[a] if a >= 0 and a < 5 else "?"
+            sb_r = " " + ("- +"[r + 1] if r in {-1, 0, 1} else "?")
             if lr is not None:
                 lr = lr.item()
-                sb = sb + thin_vert + +("- +"[lr + 1] if lr in {-1, 0, 1} else "?")
-            return sb + " " * (states.size(-1) - len(sb))
-
-        result += (
-            vert
-            + vert.join([status_bar(a, r) for a, r in zip(actions[n], rewards[n])])
-            + vert
-            + "\n"
-        )
+                sb_r = sb_r + "/" + ("- +"[lr + 1] if lr in {-1, 0, 1} else "?")
+            return sb_a + " " * (states.size(-1) - len(sb_a) - len(sb_r)) + sb_r
+
+        if lookahead_rewards is None:
+            result += (
+                vert
+                + vert.join([status_bar(a, r) for a, r in zip(actions[n], rewards[n])])
+                + vert
+                + "\n"
+            )
+        else:
+            result += (
+                vert
+                + vert.join(
+                    [
+                        status_bar(a, r, lr)
+                        for a, r, lr in zip(
+                            actions[n], rewards[n], lookahead_rewards[n]
+                        )
+                    ]
+                )
+                + vert
+                + "\n"
+            )
 
         result += hline
 
@@ -232,6 +247,6 @@ def episodes2str(
 if __name__ == "__main__":
     nb, height, width, T = 8, 4, 6, 20
     states, actions, rewards = generate_episodes(nb, height, width, T)
-    seq = episodes2seq(states, actions, rewards, lookahead_delta=2)
+    seq = episodes2seq(states, actions, rewards, lookahead_delta=5)
     s, a, r, lr = seq2episodes(seq, height, width, lookahead=True)
     print(episodes2str(s, a, r, lookahead_rewards=lr, unicode=True, ansi_colors=True))