Oops
diff --git a/greed.py b/greed.py
index 47cfb40..1025d7c 100755
--- a/greed.py
+++ b/greed.py
@@ -11,6 +11,11 @@ from torch.nn import functional as F
 
 ######################################################################
 
+REWARD_PLUS = 1
+REWARD_NONE = 0
+REWARD_MINUS = -1
+REWARD_UNKNOWN = 2
+
 
 class GreedWorld:
     def __init__(self, height=6, width=6, T=10, nb_walls=3, nb_coins=2):
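
The four constants added above name the reward symbols a sequence can carry: the three actual rewards plus a placeholder for a lookahead reward that is not yet decided. A minimal sketch of how a raw reward might map to them; the helper is hypothetical, not part of the patch:

    # Hypothetical helper (not in the patch): pick the symbol for a raw
    # reward, falling back to REWARD_UNKNOWN outside {-1, 0, 1}.
    def reward_symbol(r):
        return {1: REWARD_PLUS, 0: REWARD_NONE, -1: REWARD_MINUS}.get(r, REWARD_UNKNOWN)

    assert reward_symbol(-1) == REWARD_MINUS
    assert reward_symbol(7) == REWARD_UNKNOWN
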
@@ -36,11 +41,11 @@ class GreedWorld:
         )
 
         self.state_len = self.height * self.width
-        self.index_states = 0
-        self.index_reward = self.state_len
-        self.index_lookahead_reward = self.state_len + 1
+        self.index_lookahead_reward = 0
+        self.index_states = 1
+        self.index_reward = self.state_len + 1
         self.index_action = self.state_len + 2
-        self.it_len = self.state_len + 3  # lookahead_reward / state / action / reward
+        self.it_len = self.state_len + 3  # lookahead_reward / state / reward / action
 
     def state2code(self, r):
         return r + self.first_states_code
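
The reordering above moves the lookahead reward from position state_len + 1 to the front of each time-step block, so a block now reads lookahead_reward, then the height * width state tokens, then reward, then action, matching the updated comment. A minimal sketch of the resulting offsets, assuming the default 6x6 grid:

    # Per-time-step token layout after the patch (defaults: height = width = 6).
    height, width = 6, 6
    state_len = height * width        # 36 state tokens
    index_lookahead_reward = 0        # token 0
    index_states = 1                  # tokens 1 .. 36
    index_reward = state_len + 1      # token 37
    index_action = state_len + 2      # token 38
    it_len = state_len + 3            # 39 tokens per time step
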
@@ -172,16 +177,16 @@ class GreedWorld:
     def episodes2seq(self, states, actions, rewards):
         neg = rewards.new_zeros(rewards.size())
         pos = rewards.new_zeros(rewards.size())
-        for t in range(neg.size(1) - 1):
+        for t in range(neg.size(1)):
             neg[:, t] = rewards[:, t:].min(dim=-1).values
             pos[:, t] = rewards[:, t:].max(dim=-1).values
         s = (neg < 0).long() * neg + (neg >= 0).long() * pos
 
         return torch.cat(
             [
+                self.lookahead_reward2code(s[:, :, None]),
                 self.state2code(states.flatten(2)),
                 self.reward2code(rewards[:, :, None]),
-                self.lookahead_reward2code(s[:, :, None]),
                 self.action2code(actions[:, :, None]),
             ],
             dim=2,
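
Two fixes here: the loop now covers every time step (the old range stopped one short, leaving neg and pos at zero for the last step), and the lookahead-reward token is emitted first to match the new layout. The rule for s is: the most negative future reward if one exists, otherwise the most positive. A small worked check with made-up rewards:

    # Worked example of the lookahead-reward rule (values are made up).
    import torch

    rewards = torch.tensor([[0, 1, 0, -1, 0]])    # one episode, five steps
    neg = torch.zeros_like(rewards)
    pos = torch.zeros_like(rewards)
    for t in range(rewards.size(1)):              # now includes the last step
        neg[:, t] = rewards[:, t:].min(dim=-1).values
        pos[:, t] = rewards[:, t:].max(dim=-1).values
    s = (neg < 0).long() * neg + (neg >= 0).long() * pos
    # s == [[-1, -1, -1, -1, 0]]: a future -1 dominates until it has passed
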
@@ -189,11 +194,15 @@ class GreedWorld:
 
     def seq2episodes(self, seq):
         seq = seq.reshape(seq.size(0), -1, self.height * self.width + 3)
-        lookahead_rewards = self.code2lookahead_reward(seq[:, :, 0])
-        states = self.code2state(seq[:, :, 1 : self.height * self.width + 1])
+        lookahead_rewards = self.code2lookahead_reward(
+            seq[:, :, self.index_lookahead_reward]
+        )
+        states = self.code2state(
+            seq[:, :, self.index_states : self.height * self.width + self.index_states]
+        )
         states = states.reshape(states.size(0), states.size(1), self.height, self.width)
-        actions = self.code2action(seq[:, :, self.height * self.width + 1])
-        rewards = self.code2reward(seq[:, :, self.height * self.width + 2])
+        actions = self.code2action(seq[:, :, self.index_action])
+        rewards = self.code2reward(seq[:, :, self.index_reward])
         return lookahead_rewards, states, actions, rewards
 
     def seq2str(self, seq):
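
Since both directions now read the same index_* attributes, encode and decode stay consistent by construction. A minimal round-trip sketch with dummy all-zero episodes; it assumes each code2* helper inverts its *2code counterpart, as state2code above suggests:

    # Hypothetical round trip through episodes2seq / seq2episodes.
    import torch

    gw = GreedWorld()                 # defaults: 6x6 grid, T=10
    N, T = 2, 10                      # made-up batch of two episodes
    states = torch.zeros(N, T, gw.height, gw.width, dtype=torch.int64)
    actions = torch.zeros(N, T, dtype=torch.int64)
    rewards = torch.zeros(N, T, dtype=torch.int64)

    seq = gw.episodes2seq(states, actions, rewards)
    l, s, a, r = gw.seq2episodes(seq)
    assert (s == states).all() and (a == actions).all() and (r == rewards).all()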