######################################################################
+REWARD_PLUS = 1
+REWARD_NONE = 0
+REWARD_MINUS = -1
+REWARD_UNKNOWN = 2
+
class GreedWorld:
def __init__(self, height=6, width=6, T=10, nb_walls=3, nb_coins=2):
)
self.state_len = self.height * self.width
- self.index_states = 0
- self.index_reward = self.state_len
- self.index_lookahead_reward = self.state_len + 1
+ self.index_lookahead_reward = 0
+ self.index_states = 1
+ self.index_reward = self.state_len + 1
self.index_action = self.state_len + 2
- self.it_len = self.state_len + 3 # lookahead_reward / state / action / reward
+ self.it_len = self.state_len + 3 # lookahead_reward / state / reward / action
def state2code(self, r):
return r + self.first_states_code
return torch.cat(
[
+ self.lookahead_reward2code(s[:, :, None]),
self.state2code(states.flatten(2)),
self.reward2code(rewards[:, :, None]),
- self.lookahead_reward2code(s[:, :, None]),
self.action2code(actions[:, :, None]),
],
dim=2,
t % self.world.it_len == self.world.index_lookahead_reward
).long()
- return lr_mask * self.world.lookahead_reward2code(2) + (1 - lr_mask) * batch
+ return (
+ lr_mask * self.world.lookahead_reward2code(greed.REWARD_UNKNOWN)
+ + (1 - lr_mask) * batch
+ )
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
result[:, self.world.it_len :] = -1
# Set the lookahead_reward of the first iteration to UNKNOWN
result[:, self.world.index_lookahead_reward] = self.world.lookahead_reward2code(
- 2
+ greed.REWARD_UNKNOWN
)
t = torch.arange(result.size(1), device=result.device)[None, :]
if u > 0:
result[
:, u + self.world.index_lookahead_reward
- ] = self.world.lookahead_reward2code(2)
+ ] = self.world.lookahead_reward2code(greed.REWARD_UNKNOWN)
ar_mask = (t >= u + self.world.index_states).long() * (
t < u + self.world.index_states + self.world.state_len
).long()
# Generate the action and reward with lookahead_reward to +1
result[
:, u + self.world.index_lookahead_reward
- ] = self.world.lookahead_reward2code(1)
+ ] = self.world.lookahead_reward2code(greed.REWARD_PLUS)
ar_mask = (t >= u + self.world.index_reward).long() * (
t <= u + self.world.index_action
).long()
# Set the lookahead_reward to UNKNOWN for the next iterations
result[
:, u + self.world.index_lookahead_reward
- ] = self.world.lookahead_reward2code(2)
+ ] = self.world.lookahead_reward2code(greed.REWARD_UNKNOWN)
filename = os.path.join(result_dir, f"test_thinking_compute_{n_epoch:04d}.txt")
with open(filename, "w") as f: