From 819181b36c1af5c3c606b6dcb11a242e9c43331c Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 27 Mar 2024 21:35:32 +0100 Subject: [PATCH] Update. --- greed.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/greed.py b/greed.py index 47cfb40..636c13b 100755 --- a/greed.py +++ b/greed.py @@ -172,7 +172,7 @@ class GreedWorld: def episodes2seq(self, states, actions, rewards): neg = rewards.new_zeros(rewards.size()) pos = rewards.new_zeros(rewards.size()) - for t in range(neg.size(1) - 1): + for t in range(neg.size(1)): neg[:, t] = rewards[:, t:].min(dim=-1).values pos[:, t] = rewards[:, t:].max(dim=-1).values s = (neg < 0).long() * neg + (neg >= 0).long() * pos @@ -189,11 +189,15 @@ class GreedWorld: def seq2episodes(self, seq): seq = seq.reshape(seq.size(0), -1, self.height * self.width + 3) - lookahead_rewards = self.code2lookahead_reward(seq[:, :, 0]) - states = self.code2state(seq[:, :, 1 : self.height * self.width + 1]) + lookahead_rewards = self.code2lookahead_reward( + seq[:, :, self.index_lookahead_reward] + ) + states = self.code2state( + seq[:, :, self.index_states : self.height * self.width + self.index_states] + ) states = states.reshape(states.size(0), states.size(1), self.height, self.width) - actions = self.code2action(seq[:, :, self.height * self.width + 1]) - rewards = self.code2reward(seq[:, :, self.height * self.width + 2]) + actions = self.code2action(seq[:, :, self.index_action]) + rewards = self.code2reward(seq[:, :, self.index_reward]) return lookahead_rewards, states, actions, rewards def seq2str(self, seq): -- 2.39.5