def status_bar(a, r, lr=None):
a, r = a.item(), r.item()
- sb = "ISNEW"[a] if a >= 0 and a < 5 else "?"
- sb = sb + thin_vert + ("- +"[r + 1] if r in {-1, 0, 1} else "?")
+ sb_a = "ISNEW"[a] if a >= 0 and a < 5 else "?"
+ sb_r = " " + ("- +"[r + 1] if r in {-1, 0, 1} else "?")
if lr is not None:
lr = lr.item()
- sb = sb + thin_vert + +("- +"[lr + 1] if lr in {-1, 0, 1} else "?")
- return sb + " " * (states.size(-1) - len(sb))
-
- result += (
- vert
- + vert.join([status_bar(a, r) for a, r in zip(actions[n], rewards[n])])
- + vert
- + "\n"
- )
+ sb_r = sb_r + "/" + ("- +"[lr + 1] if lr in {-1, 0, 1} else "?")
+ return sb_a + " " * (states.size(-1) - len(sb_a) - len(sb_r)) + sb_r
+
+ if lookahead_rewards is None:
+ result += (
+ vert
+ + vert.join([status_bar(a, r) for a, r in zip(actions[n], rewards[n])])
+ + vert
+ + "\n"
+ )
+ else:
+ result += (
+ vert
+ + vert.join(
+ [
+ status_bar(a, r, lr)
+ for a, r, lr in zip(
+ actions[n], rewards[n], lookahead_rewards[n]
+ )
+ ]
+ )
+ + vert
+ + "\n"
+ )
result += hline
if __name__ == "__main__":
nb, height, width, T = 8, 4, 6, 20
states, actions, rewards = generate_episodes(nb, height, width, T)
- seq = episodes2seq(states, actions, rewards, lookahead_delta=2)
+ seq = episodes2seq(states, actions, rewards, lookahead_delta=5)
s, a, r, lr = seq2episodes(seq, height, width, lookahead=True)
print(episodes2str(s, a, r, lookahead_rewards=lr, unicode=True, ansi_colors=True))