X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=8e8faa989dea970eb86def7cf8e3887e191db947;hb=e93880cbc65a7f2a2ee0eacdfccf89c947a32fb0;hp=870ab95e913e7597a07494ba40dd595200ae1f4c;hpb=08b58304225e044a21419dd30302d985acc1824c;p=picoclvr.git diff --git a/tasks.py b/tasks.py index 870ab95..8e8faa9 100755 --- a/tasks.py +++ b/tasks.py @@ -1951,8 +1951,11 @@ class Escape(Task): # Generate iteration after iteration result = self.test_input[:250].clone() + # Erase all the content but that of the first iteration result[:, self.it_len :] = -1 + # Set the lookahead_reward of the first to UNKNOWN result[:, self.index_lookahead_reward] = escape.lookahead_reward2code(2) + t = torch.arange(result.size(1), device=result.device)[None, :] for u in tqdm.tqdm(