From 232299b8af7e66a02e64bb2e47b525e2f50b099d Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 28 Mar 2024 08:18:01 +0100 Subject: [PATCH] Update. --- greed.py | 15 ++++++++++----- tasks.py | 13 ++++++++----- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/greed.py b/greed.py index 636c13b..1025d7c 100755 --- a/greed.py +++ b/greed.py @@ -11,6 +11,11 @@ from torch.nn import functional as F ###################################################################### +REWARD_PLUS = 1 +REWARD_NONE = 0 +REWARD_MINUS = -1 +REWARD_UNKNOWN = 2 + class GreedWorld: def __init__(self, height=6, width=6, T=10, nb_walls=3, nb_coins=2): @@ -36,11 +41,11 @@ class GreedWorld: ) self.state_len = self.height * self.width - self.index_states = 0 - self.index_reward = self.state_len - self.index_lookahead_reward = self.state_len + 1 + self.index_lookahead_reward = 0 + self.index_states = 1 + self.index_reward = self.state_len + 1 self.index_action = self.state_len + 2 - self.it_len = self.state_len + 3 # lookahead_reward / state / action / reward + self.it_len = self.state_len + 3 # lookahead_reward / state / reward / action def state2code(self, r): return r + self.first_states_code @@ -179,9 +184,9 @@ class GreedWorld: return torch.cat( [ + self.lookahead_reward2code(s[:, :, None]), self.state2code(states.flatten(2)), self.reward2code(rewards[:, :, None]), - self.lookahead_reward2code(s[:, :, None]), self.action2code(actions[:, :, None]), ], dim=2, diff --git a/tasks.py b/tasks.py index 6a7e639..324376d 100755 --- a/tasks.py +++ b/tasks.py @@ -1905,7 +1905,10 @@ class Greed(Task): t % self.world.it_len == self.world.index_lookahead_reward ).long() - return lr_mask * self.world.lookahead_reward2code(2) + (1 - lr_mask) * batch + return ( + lr_mask * self.world.lookahead_reward2code(greed.REWARD_UNKNOWN) + + (1 - lr_mask) * batch + ) def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} @@ -1950,7 +1953,7 @@ class Greed(Task): result[:, self.world.it_len :] = -1 # Set the lookahead_reward of the firs to UNKNOWN result[:, self.world.index_lookahead_reward] = self.world.lookahead_reward2code( - 2 + greed.REWARD_UNKNOWN ) t = torch.arange(result.size(1), device=result.device)[None, :] @@ -1965,7 +1968,7 @@ class Greed(Task): if u > 0: result[ :, u + self.world.index_lookahead_reward - ] = self.world.lookahead_reward2code(2) + ] = self.world.lookahead_reward2code(greed.REWARD_UNKNOWN) ar_mask = (t >= u + self.world.index_states).long() * ( t < u + self.world.index_states + self.world.state_len ).long() @@ -1974,7 +1977,7 @@ class Greed(Task): # Generate the action and reward with lookahead_reward to +1 result[ :, u + self.world.index_lookahead_reward - ] = self.world.lookahead_reward2code(1) + ] = self.world.lookahead_reward2code(greed.REWARD_PLUS) ar_mask = (t >= u + self.world.index_reward).long() * ( t <= u + self.world.index_action ).long() @@ -1983,7 +1986,7 @@ class Greed(Task): # Set the lookahead_reward to UNKNOWN for the next iterations result[ :, u + self.world.index_lookahead_reward - ] = self.world.lookahead_reward2code(2) + ] = self.world.lookahead_reward2code(gree.REWARD_UNKNOWN) filename = os.path.join(result_dir, f"test_thinking_compute_{n_epoch:04d}.txt") with open(filename, "w") as f: -- 2.39.5