def episodes2seq(self, states, actions, rewards):
neg = rewards.new_zeros(rewards.size())
pos = rewards.new_zeros(rewards.size())
- for t in range(neg.size(1) - 1):
+ for t in range(neg.size(1)):
neg[:, t] = rewards[:, t:].min(dim=-1).values
pos[:, t] = rewards[:, t:].max(dim=-1).values
s = (neg < 0).long() * neg + (neg >= 0).long() * pos
def seq2episodes(self, seq):
    """Split a flat token sequence into its per-step episode components.

    ``seq`` is reshaped to ``(seq.size(0), -1, height * width + 3)``: every
    step occupies ``height * width`` state tokens plus three scalar tokens
    (look-ahead reward, action, reward).  The position of each component
    inside a step is read from ``self.index_lookahead_reward``,
    ``self.index_states``, ``self.index_action`` and ``self.index_reward``
    instead of being hard-coded, and each slice is passed through the
    matching ``self.code2*`` decoder.

    Args:
        seq: object exposing ``size(dim)``, ``reshape`` and multi-dimensional
            slicing (presumably a ``torch.Tensor`` of token codes -- confirm
            against the caller).  Its element count per row must be a
            multiple of ``height * width + 3`` for the reshape to succeed.

    Returns:
        A tuple ``(lookahead_rewards, states, actions, rewards)`` where
        ``states`` has been reshaped to
        ``(states.size(0), states.size(1), height, width)`` and the other
        three are whatever the corresponding decoder returns for its
        per-step slice.
    """
    # NOTE(review): the self.index_* attributes are defined outside this
    # block; they must describe offsets within one (height * width + 3)
    # token step.
    n_cells = self.height * self.width

    # One row per episode, one (n_cells + 3)-token record per step.
    seq = seq.reshape(seq.size(0), -1, n_cells + 3)

    lookahead_rewards = self.code2lookahead_reward(
        seq[:, :, self.index_lookahead_reward]
    )

    # The state occupies n_cells consecutive tokens starting at index_states.
    states = self.code2state(
        seq[:, :, self.index_states : self.index_states + n_cells]
    )
    # Restore the 2-D grid layout of each decoded state.
    states = states.reshape(
        states.size(0), states.size(1), self.height, self.width
    )

    actions = self.code2action(seq[:, :, self.index_action])
    rewards = self.code2reward(seq[:, :, self.index_reward])

    return lookahead_rewards, states, actions, rewards
def seq2str(self, seq):