X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=escape.py;h=f945172d7db242411431a8a4c1b83dcd770a698b;hb=8855d37cef610b39f37d2b3b331046d1e7040a37;hp=a8b9536bbb5673937f157794424695fc8719b567;hpb=9664af37378218468190741c9ea5c3d7cb231926;p=picoclvr.git diff --git a/escape.py b/escape.py index a8b9536..f945172 100755 --- a/escape.py +++ b/escape.py @@ -111,13 +111,12 @@ def episodes2seq(states, actions, rewards, lookahead_delta=None): if lookahead_delta is not None: r = rewards - print(f"{r.size()=} {lookahead_delta=}") u = F.pad(r, (0, lookahead_delta - 1)).as_strided( (r.size(0), r.size(1), lookahead_delta), (r.size(1) + lookahead_delta - 1, 1, 1), ) - a = u.min(dim=-1).values - b = u.max(dim=-1).values + a = u[:, :, 1:].min(dim=-1).values + b = u[:, :, 1:].max(dim=-1).values s = (a < 0).long() * a + (a >= 0).long() * b lookahead_rewards = (1 + s[:, :, None]) + first_lookahead_rewards_code @@ -204,19 +203,34 @@ def episodes2str( def status_bar(a, r, lr=None): a, r = a.item(), r.item() - sb = "ISNEW"[a] if a >= 0 and a < 5 else "?" - sb = sb + thin_vert + ("- +"[r + 1] if r in {-1, 0, 1} else "?") + sb_a = "ISNEW"[a] if a >= 0 and a < 5 else "?" + sb_r = " " + ("- +"[r + 1] if r in {-1, 0, 1} else "?") if lr is not None: lr = lr.item() - sb = sb + thin_vert + +("- +"[lr + 1] if lr in {-1, 0, 1} else "?") - return sb + " " * (states.size(-1) - len(sb)) - - result += ( - vert - + vert.join([status_bar(a, r) for a, r in zip(actions[n], rewards[n])]) - + vert - + "\n" - ) + sb_r = sb_r + "/" + ("- +"[lr + 1] if lr in {-1, 0, 1} else "?") + return sb_a + " " * (states.size(-1) - len(sb_a) - len(sb_r)) + sb_r + + if lookahead_rewards is None: + result += ( + vert + + vert.join([status_bar(a, r) for a, r in zip(actions[n], rewards[n])]) + + vert + + "\n" + ) + else: + result += ( + vert + + vert.join( + [ + status_bar(a, r, lr) + for a, r, lr in zip( + actions[n], rewards[n], lookahead_rewards[n] + ) + ] + ) + + vert + + "\n" + ) result += hline @@ -230,8 +244,8 @@ def episodes2str( ###################################################################### if __name__ == "__main__": - nb, height, width, T = 8, 4, 6, 20 + nb, height, width, T = 10, 4, 6, 20 states, actions, rewards = generate_episodes(nb, height, width, T) - seq = episodes2seq(states, actions, rewards, lookahead_delta=2) + seq = episodes2seq(states, actions, rewards, lookahead_delta=T) s, a, r, lr = seq2episodes(seq, height, width, lookahead=True) print(episodes2str(s, a, r, lookahead_rewards=lr, unicode=True, ansi_colors=True))