X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=escape.py;h=fc4fbbc97a905e4bb45b120f5f52283a2142afea;hb=9bf6dde1249fde5ba0ca2688599d8dd324d8c503;hp=93e305228a05475b096aba439d0b5625d8b1c6cd;hpb=3168a3161668caacb36ebd717e308e36c9eef2b1;p=picoclvr.git diff --git a/escape.py b/escape.py index 93e3052..fc4fbbc 100755 --- a/escape.py +++ b/escape.py @@ -111,13 +111,12 @@ def episodes2seq(states, actions, rewards, lookahead_delta=None): if lookahead_delta is not None: r = rewards - print(f"{r.size()=} {lookahead_delta=}") u = F.pad(r, (0, lookahead_delta - 1)).as_strided( (r.size(0), r.size(1), lookahead_delta), (r.size(1) + lookahead_delta - 1, 1, 1), ) - a = u.min(dim=-1).values - b = u.max(dim=-1).values + a = u[:, :, 1:].min(dim=-1).values + b = u[:, :, 1:].max(dim=-1).values s = (a < 0).long() * a + (a >= 0).long() * b lookahead_rewards = (1 + s[:, :, None]) + first_lookahead_rewards_code