Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 27 Mar 2024 20:35:32 +0000 (21:35 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 27 Mar 2024 20:35:32 +0000 (21:35 +0100)
greed.py

index 47cfb40..636c13b 100755 (executable)
--- a/greed.py
+++ b/greed.py
@@ -172,7 +172,7 @@ class GreedWorld:
     def episodes2seq(self, states, actions, rewards):
         neg = rewards.new_zeros(rewards.size())
         pos = rewards.new_zeros(rewards.size())
-        for t in range(neg.size(1) - 1):
+        for t in range(neg.size(1)):
             neg[:, t] = rewards[:, t:].min(dim=-1).values
             pos[:, t] = rewards[:, t:].max(dim=-1).values
         s = (neg < 0).long() * neg + (neg >= 0).long() * pos
@@ -189,11 +189,15 @@ class GreedWorld:
 
     def seq2episodes(self, seq):
         seq = seq.reshape(seq.size(0), -1, self.height * self.width + 3)
-        lookahead_rewards = self.code2lookahead_reward(seq[:, :, 0])
-        states = self.code2state(seq[:, :, 1 : self.height * self.width + 1])
+        lookahead_rewards = self.code2lookahead_reward(
+            seq[:, :, self.index_lookahead_reward]
+        )
+        states = self.code2state(
+            seq[:, :, self.index_states : self.height * self.width + self.index_states]
+        )
         states = states.reshape(states.size(0), states.size(1), self.height, self.width)
-        actions = self.code2action(seq[:, :, self.height * self.width + 1])
-        rewards = self.code2reward(seq[:, :, self.height * self.width + 2])
+        actions = self.code2action(seq[:, :, self.index_action])
+        rewards = self.code2reward(seq[:, :, self.index_reward])
         return lookahead_rewards, states, actions, rewards
 
     def seq2str(self, seq):