projects
/
picoclvr.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[picoclvr.git]
/
greed.py
diff --git a/greed.py b/greed.py
index 636c13b..1025d7c 100755 (executable)
--- a/greed.py
+++ b/greed.py
@@ -11,6 +11,11 @@
from torch.nn import functional as F
######################################################################
######################################################################
+REWARD_PLUS = 1
+REWARD_NONE = 0
+REWARD_MINUS = -1
+REWARD_UNKNOWN = 2
+
class GreedWorld:
def __init__(self, height=6, width=6, T=10, nb_walls=3, nb_coins=2):
class GreedWorld:
def __init__(self, height=6, width=6, T=10, nb_walls=3, nb_coins=2):
@@ -36,11 +41,11 @@
class GreedWorld:
)
self.state_len = self.height * self.width
)
self.state_len = self.height * self.width
- self.index_states = 0
- self.index_reward = self.state_len
- self.index_lookahead_reward = self.state_len + 1
+ self.index_lookahead_reward = 0
+ self.index_states = 1
+ self.index_reward = self.state_len + 1
self.index_action = self.state_len + 2
self.index_action = self.state_len + 2
- self.it_len = self.state_len + 3 # lookahead_reward / state / action / reward
+ self.it_len = self.state_len + 3 # lookahead_reward / state / reward / action
def state2code(self, r):
return r + self.first_states_code
def state2code(self, r):
return r + self.first_states_code
@@ -179,9 +184,9 @@
class GreedWorld:
return torch.cat(
[
return torch.cat(
[
+ self.lookahead_reward2code(s[:, :, None]),
self.state2code(states.flatten(2)),
self.reward2code(rewards[:, :, None]),
self.state2code(states.flatten(2)),
self.reward2code(rewards[:, :, None]),
- self.lookahead_reward2code(s[:, :, None]),
self.action2code(actions[:, :, None]),
],
dim=2,
self.action2code(actions[:, :, None]),
],
dim=2,