X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=greed.py;h=1025d7c6b324564e108e07e971d3bc2816328ee8;hb=232299b8af7e66a02e64bb2e47b525e2f50b099d;hp=636c13b3cbd09580d65b18b5b701c8510827f732;hpb=819181b36c1af5c3c606b6dcb11a242e9c43331c;p=picoclvr.git diff --git a/greed.py b/greed.py index 636c13b..1025d7c 100755 --- a/greed.py +++ b/greed.py @@ -11,6 +11,11 @@ from torch.nn import functional as F ###################################################################### +REWARD_PLUS = 1 +REWARD_NONE = 0 +REWARD_MINUS = -1 +REWARD_UNKNOWN = 2 + class GreedWorld: def __init__(self, height=6, width=6, T=10, nb_walls=3, nb_coins=2): @@ -36,11 +41,11 @@ class GreedWorld: ) self.state_len = self.height * self.width - self.index_states = 0 - self.index_reward = self.state_len - self.index_lookahead_reward = self.state_len + 1 + self.index_lookahead_reward = 0 + self.index_states = 1 + self.index_reward = self.state_len + 1 self.index_action = self.state_len + 2 - self.it_len = self.state_len + 3 # lookahead_reward / state / action / reward + self.it_len = self.state_len + 3 # lookahead_reward / state / reward / action def state2code(self, r): return r + self.first_states_code @@ -179,9 +184,9 @@ class GreedWorld: return torch.cat( [ + self.lookahead_reward2code(s[:, :, None]), self.state2code(states.flatten(2)), self.reward2code(rewards[:, :, None]), - self.lookahead_reward2code(s[:, :, None]), self.action2code(actions[:, :, None]), ], dim=2,