From 2bb362ef385f477da4af7d8679cc94d42cf6c146 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 25 Mar 2024 17:09:53 +0100 Subject: [PATCH] Update. --- tasks.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/tasks.py b/tasks.py index 1d967f9..12c6125 100755 --- a/tasks.py +++ b/tasks.py @@ -1978,23 +1978,19 @@ class Escape(Task): with open(filename, "w") as f: for n in range(10): for s in snapshots: - s, a, r, lr = escape.seq2episodes( + lr, s, a, r = escape.seq2episodes( s[n : n + 1], self.height, self.width, lookahead=True ) str = escape.episodes2str( - s, a, r, lookahead_rewards=lr, unicode=True, ansi_colors=True + lr, s, a, r, unicode=True, ansi_colors=True ) f.write(str) f.write("\n\n") # Saving the generated sequences - s, a, r, lr = escape.seq2episodes( - result, self.height, self.width, lookahead=True - ) - str = escape.episodes2str( - s, a, r, lookahead_rewards=lr, unicode=True, ansi_colors=True - ) + s, a, r, lr = escape.seq2episodes(result, self.height, self.width) + str = escape.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True) filename = os.path.join(result_dir, f"test_thinking_seq_{n_epoch:04d}.txt") with open(filename, "w") as f: @@ -2009,11 +2005,11 @@ class Escape(Task): # Saving the ground truth s, a, r, lr = escape.seq2episodes( - result, self.height, self.width, lookahead=True - ) - str = escape.episodes2str( - s, a, r, lookahead_rewards=lr, unicode=True, ansi_colors=True + result, + self.height, + self.width, ) + str = escape.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True) filename = os.path.join(result_dir, f"test_true_seq_{n_epoch:04d}.txt") with open(filename, "w") as f: @@ -2041,11 +2037,11 @@ class Escape(Task): # Saving the generated sequences s, a, r, lr = escape.seq2episodes( - result, self.height, self.width, lookahead=True - ) - str = escape.episodes2str( - s, a, r, lookahead_rewards=lr, unicode=True, ansi_colors=True + result, + self.height, + self.width, ) + str = escape.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True) filename = os.path.join(result_dir, f"test_seq_{n_epoch:04d}.txt") with open(filename, "w") as f: -- 2.39.5