From fdc61b7e50e029aac58b10f377acdce549532f84 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 24 Mar 2024 11:29:15 +0100 Subject: [PATCH] Update. --- escape.py | 39 +++++++++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/escape.py b/escape.py index a8b9536..93e3052 100755 --- a/escape.py +++ b/escape.py @@ -204,19 +204,34 @@ def episodes2str( def status_bar(a, r, lr=None): a, r = a.item(), r.item() - sb = "ISNEW"[a] if a >= 0 and a < 5 else "?" - sb = sb + thin_vert + ("- +"[r + 1] if r in {-1, 0, 1} else "?") + sb_a = "ISNEW"[a] if a >= 0 and a < 5 else "?" + sb_r = " " + ("- +"[r + 1] if r in {-1, 0, 1} else "?") if lr is not None: lr = lr.item() - sb = sb + thin_vert + +("- +"[lr + 1] if lr in {-1, 0, 1} else "?") - return sb + " " * (states.size(-1) - len(sb)) - - result += ( - vert - + vert.join([status_bar(a, r) for a, r in zip(actions[n], rewards[n])]) - + vert - + "\n" - ) + sb_r = sb_r + "/" + ("- +"[lr + 1] if lr in {-1, 0, 1} else "?") + return sb_a + " " * (states.size(-1) - len(sb_a) - len(sb_r)) + sb_r + + if lookahead_rewards is None: + result += ( + vert + + vert.join([status_bar(a, r) for a, r in zip(actions[n], rewards[n])]) + + vert + + "\n" + ) + else: + result += ( + vert + + vert.join( + [ + status_bar(a, r, lr) + for a, r, lr in zip( + actions[n], rewards[n], lookahead_rewards[n] + ) + ] + ) + + vert + + "\n" + ) result += hline @@ -232,6 +247,6 @@ def episodes2str( if __name__ == "__main__": nb, height, width, T = 8, 4, 6, 20 states, actions, rewards = generate_episodes(nb, height, width, T) - seq = episodes2seq(states, actions, rewards, lookahead_delta=2) + seq = episodes2seq(states, actions, rewards, lookahead_delta=5) s, a, r, lr = seq2episodes(seq, height, width, lookahead=True) print(episodes2str(s, a, r, lookahead_rewards=lr, unicode=True, ansi_colors=True)) -- 2.39.5