######################################################################
+REWARD_PLUS = 1
+REWARD_NONE = 0
+REWARD_MINUS = -1
+REWARD_UNKNOWN = 2
+
class GreedWorld:
def __init__(self, height=6, width=6, T=10, nb_walls=3, nb_coins=2):
)
self.state_len = self.height * self.width
- self.index_states = 0
- self.index_reward = self.state_len
- self.index_lookahead_reward = self.state_len + 1
+ self.index_lookahead_reward = 0
+ self.index_states = 1
+ self.index_reward = self.state_len + 1
self.index_action = self.state_len + 2
- self.it_len = self.state_len + 3 # lookahead_reward / state / action / reward
+ self.it_len = self.state_len + 3 # lookahead_reward / state / reward / action
def state2code(self, r):
return r + self.first_states_code
return torch.cat(
[
+ self.lookahead_reward2code(s[:, :, None]),
self.state2code(states.flatten(2)),
self.reward2code(rewards[:, :, None]),
- self.lookahead_reward2code(s[:, :, None]),
self.action2code(actions[:, :, None]),
],
dim=2,