X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=greed.py;h=1025d7c6b324564e108e07e971d3bc2816328ee8;hb=232299b8af7e66a02e64bb2e47b525e2f50b099d;hp=47cfb40bb00d638a8a9a8d38f69e8afe8fc7c55f;hpb=41164ce7ce1d071a4eb71f72ff277933794cf316;p=picoclvr.git

diff --git a/greed.py b/greed.py
index 47cfb40..1025d7c 100755
--- a/greed.py
+++ b/greed.py
@@ -11,6 +11,11 @@ from torch.nn import functional as F
 
 ######################################################################
 
+REWARD_PLUS = 1
+REWARD_NONE = 0
+REWARD_MINUS = -1
+REWARD_UNKNOWN = 2
+
 
 class GreedWorld:
     def __init__(self, height=6, width=6, T=10, nb_walls=3, nb_coins=2):
@@ -36,11 +41,11 @@ class GreedWorld:
         )
 
         self.state_len = self.height * self.width
-        self.index_states = 0
-        self.index_reward = self.state_len
-        self.index_lookahead_reward = self.state_len + 1
+        self.index_lookahead_reward = 0
+        self.index_states = 1
+        self.index_reward = self.state_len + 1
         self.index_action = self.state_len + 2
-        self.it_len = self.state_len + 3  # lookahead_reward / state / action / reward
+        self.it_len = self.state_len + 3  # lookahead_reward / state / reward / action
 
     def state2code(self, r):
         return r + self.first_states_code
@@ -172,16 +177,16 @@ class GreedWorld:
     def episodes2seq(self, states, actions, rewards):
         neg = rewards.new_zeros(rewards.size())
         pos = rewards.new_zeros(rewards.size())
-        for t in range(neg.size(1) - 1):
+        for t in range(neg.size(1)):
             neg[:, t] = rewards[:, t:].min(dim=-1).values
             pos[:, t] = rewards[:, t:].max(dim=-1).values
         s = (neg < 0).long() * neg + (neg >= 0).long() * pos
 
         return torch.cat(
             [
+                self.lookahead_reward2code(s[:, :, None]),
                 self.state2code(states.flatten(2)),
                 self.reward2code(rewards[:, :, None]),
-                self.lookahead_reward2code(s[:, :, None]),
                 self.action2code(actions[:, :, None]),
             ],
             dim=2,
@@ -189,11 +194,15 @@ class GreedWorld:
 
     def seq2episodes(self, seq):
         seq = seq.reshape(seq.size(0), -1, self.height * self.width + 3)
-        lookahead_rewards = self.code2lookahead_reward(seq[:, :, 0])
-        states = self.code2state(seq[:, :, 1 : self.height * self.width + 1])
+        lookahead_rewards = self.code2lookahead_reward(
+            seq[:, :, self.index_lookahead_reward]
+        )
+        states = self.code2state(
+            seq[:, :, self.index_states : self.height * self.width + self.index_states]
+        )
         states = states.reshape(states.size(0), states.size(1), self.height, self.width)
-        actions = self.code2action(seq[:, :, self.height * self.width + 1])
-        rewards = self.code2reward(seq[:, :, self.height * self.width + 2])
+        actions = self.code2action(seq[:, :, self.index_action])
+        rewards = self.code2reward(seq[:, :, self.index_reward])
         return lookahead_rewards, states, actions, rewards
 
     def seq2str(self, seq):