From e93880cbc65a7f2a2ee0eacdfccf89c947a32fb0 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 26 Mar 2024 09:30:05 +0100 Subject: [PATCH] Update. --- escape.py | 2 +- tasks.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/escape.py b/escape.py index 7596bea..5f34cd1 100755 --- a/escape.py +++ b/escape.py @@ -242,7 +242,7 @@ def episodes2str( def status_bar(a, r, lr=None): a, r = a.item(), r.item() sb_a = "ISNEW"[a] if a >= 0 and a < 5 else "?" - sb_r = "- +"[r + 1] if r in {-1, 0, 1} else "?" + sb_r = "- +U"[r + 1] if r in {-1, 0, 1, 2} else "?" if lr is None: sb_lr = "" else: diff --git a/tasks.py b/tasks.py index 870ab95..8e8faa9 100755 --- a/tasks.py +++ b/tasks.py @@ -1951,8 +1951,11 @@ class Escape(Task): # Generate iteration after iteration result = self.test_input[:250].clone() + # Erase all the content but that of the first iteration result[:, self.it_len :] = -1 + # Set the lookahead_reward of the firs to UNKNOWN result[:, self.index_lookahead_reward] = escape.lookahead_reward2code(2) + t = torch.arange(result.size(1), device=result.device)[None, :] for u in tqdm.tqdm( -- 2.39.5