X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=escape.py;h=f945172d7db242411431a8a4c1b83dcd770a698b;hb=8855d37cef610b39f37d2b3b331046d1e7040a37;hp=93e305228a05475b096aba439d0b5625d8b1c6cd;hpb=fdc61b7e50e029aac58b10f377acdce549532f84;p=picoclvr.git diff --git a/escape.py b/escape.py index 93e3052..f945172 100755 --- a/escape.py +++ b/escape.py @@ -111,13 +111,12 @@ def episodes2seq(states, actions, rewards, lookahead_delta=None): if lookahead_delta is not None: r = rewards - print(f"{r.size()=} {lookahead_delta=}") u = F.pad(r, (0, lookahead_delta - 1)).as_strided( (r.size(0), r.size(1), lookahead_delta), (r.size(1) + lookahead_delta - 1, 1, 1), ) - a = u.min(dim=-1).values - b = u.max(dim=-1).values + a = u[:, :, 1:].min(dim=-1).values + b = u[:, :, 1:].max(dim=-1).values s = (a < 0).long() * a + (a >= 0).long() * b lookahead_rewards = (1 + s[:, :, None]) + first_lookahead_rewards_code @@ -245,8 +244,8 @@ def episodes2str( ###################################################################### if __name__ == "__main__": - nb, height, width, T = 8, 4, 6, 20 + nb, height, width, T = 10, 4, 6, 20 states, actions, rewards = generate_episodes(nb, height, width, T) - seq = episodes2seq(states, actions, rewards, lookahead_delta=5) + seq = episodes2seq(states, actions, rewards, lookahead_delta=T) s, a, r, lr = seq2episodes(seq, height, width, lookahead=True) print(episodes2str(s, a, r, lookahead_rewards=lr, unicode=True, ansi_colors=True))