From 9bf6dde1249fde5ba0ca2688599d8dd324d8c503 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 24 Mar 2024 11:36:44 +0100 Subject: [PATCH] Update. --- escape.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/escape.py b/escape.py index 93e3052..fc4fbbc 100755 --- a/escape.py +++ b/escape.py @@ -111,13 +111,12 @@ def episodes2seq(states, actions, rewards, lookahead_delta=None): if lookahead_delta is not None: r = rewards - print(f"{r.size()=} {lookahead_delta=}") u = F.pad(r, (0, lookahead_delta - 1)).as_strided( (r.size(0), r.size(1), lookahead_delta), (r.size(1) + lookahead_delta - 1, 1, 1), ) - a = u.min(dim=-1).values - b = u.max(dim=-1).values + a = u[:, :, 1:].min(dim=-1).values + b = u[:, :, 1:].max(dim=-1).values s = (a < 0).long() * a + (a >= 0).long() * b lookahead_rewards = (1 + s[:, :, None]) + first_lookahead_rewards_code -- 2.39.5