######################################################################
+REWARD_PLUS = 1
+REWARD_NONE = 0
+REWARD_MINUS = -1
+REWARD_UNKNOWN = 2
+
class GreedWorld:
def __init__(self, height=6, width=6, T=10, nb_walls=3, nb_coins=2):
)
self.state_len = self.height * self.width
- self.index_states = 0
- self.index_reward = self.state_len
- self.index_lookahead_reward = self.state_len + 1
+ self.index_lookahead_reward = 0
+ self.index_states = 1
+ self.index_reward = self.state_len + 1
self.index_action = self.state_len + 2
- self.it_len = self.state_len + 3 # lookahead_reward / state / action / reward
+ self.it_len = self.state_len + 3 # lookahead_reward / state / reward / action
def state2code(self, r):
return r + self.first_states_code
def episodes2seq(self, states, actions, rewards):
neg = rewards.new_zeros(rewards.size())
pos = rewards.new_zeros(rewards.size())
- for t in range(neg.size(1) - 1):
+ for t in range(neg.size(1)):
neg[:, t] = rewards[:, t:].min(dim=-1).values
pos[:, t] = rewards[:, t:].max(dim=-1).values
s = (neg < 0).long() * neg + (neg >= 0).long() * pos
return torch.cat(
[
+ self.lookahead_reward2code(s[:, :, None]),
self.state2code(states.flatten(2)),
self.reward2code(rewards[:, :, None]),
- self.lookahead_reward2code(s[:, :, None]),
self.action2code(actions[:, :, None]),
],
dim=2,
def seq2episodes(self, seq):
seq = seq.reshape(seq.size(0), -1, self.height * self.width + 3)
- lookahead_rewards = self.code2lookahead_reward(seq[:, :, 0])
- states = self.code2state(seq[:, :, 1 : self.height * self.width + 1])
+ lookahead_rewards = self.code2lookahead_reward(
+ seq[:, :, self.index_lookahead_reward]
+ )
+ states = self.code2state(
+ seq[:, :, self.index_states : self.height * self.width + self.index_states]
+ )
states = states.reshape(states.size(0), states.size(1), self.height, self.width)
- actions = self.code2action(seq[:, :, self.height * self.width + 1])
- rewards = self.code2reward(seq[:, :, self.height * self.width + 2])
+ actions = self.code2action(seq[:, :, self.index_action])
+ rewards = self.code2reward(seq[:, :, self.index_reward])
return lookahead_rewards, states, actions, rewards
def seq2str(self, seq):