projects
/
pytorch.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
d507a6f
)
Update.
author
François Fleuret
<francois@fleuret.org>
Sun, 12 Nov 2023 07:14:52 +0000
(08:14 +0100)
committer
François Fleuret
<francois@fleuret.org>
Sun, 12 Nov 2023 07:14:52 +0000
(08:14 +0100)
picocrafter.py
patch
|
blob
|
history
diff --git
a/picocrafter.py
b/picocrafter.py
index
36088ac
..
5bd6a48
100755
(executable)
--- a/
picocrafter.py
+++ b/
picocrafter.py
@@
-74,7
+74,9
@@
def to_unicode(s):
def fusion_multi_lines(l, width_min=0):
def fusion_multi_lines(l, width_min=0):
- l = [x if type(x) is list else [str(x)] for x in l]
+ l = [x if type(x) is str else str(x) for x in l]
+
+ l = [x.split("\n") for x in l]
def center(r, w):
k = w - len(r)
def center(r, w):
k = w - len(r)
@@
-90,7
+92,7
@@
def fusion_multi_lines(l, width_min=0):
return "\n".join(["|".join([o[k] for o in l]) for k in range(h)])
return "\n".join(["|".join([o[k] for o in l]) for k in range(h)])
-class PicroCrafterEngine:
+class PicroCrafterEnvironment:
def __init__(
self,
world_height=27,
def __init__(
self,
world_height=27,
@@
-246,7
+248,7
@@
class PicroCrafterEngine:
else:
return "?"
else:
return "?"
- def nb_view_tiles(self):
+ def nb_state_token_values(self):
return len(self.tiles)
def min_max_reward(self):
return len(self.tiles)
def min_max_reward(self):
@@
-277,14
+279,18
@@
class PicroCrafterEngine:
nb_hits = self.monster_moves()
nb_hits = self.monster_moves()
- alive_before = self.life_level_in_100th > 99
+ alive_before = self.life_level_in_100th >= 100
+
self.life_level_in_100th[alive_before] = (
self.life_level_in_100th[alive_before]
+ self.life_level_gain_100th
- nb_hits[alive_before] * 100
).clamp(max=self.life_level_max * 100 + 99)
self.life_level_in_100th[alive_before] = (
self.life_level_in_100th[alive_before]
+ self.life_level_gain_100th
- nb_hits[alive_before] * 100
).clamp(max=self.life_level_max * 100 + 99)
- alive_after = self.life_level_in_100th > 99
+
+ alive_after = self.life_level_in_100th >= 100
+
self.worlds[torch.logical_not(alive_after)] = self.tile2id["#"]
self.worlds[torch.logical_not(alive_after)] = self.tile2id["#"]
+
reward = nb_hits * self.reward_per_hit
for i in range(q.size(0)):
reward = nb_hits * self.reward_per_hit
for i in range(q.size(0)):
@@
-311,6
+317,7
@@
class PicroCrafterEngine:
)
reward[torch.logical_not(alive_before)] = 0
)
reward[torch.logical_not(alive_before)] = 0
+
return reward, inventory, self.life_level_in_100th // 100
def monster_moves(self):
return reward, inventory, self.life_level_in_100th // 100
def monster_moves(self):
@@
-382,7
+389,10
@@
class PicroCrafterEngine:
return nb_hits
return nb_hits
- def views(self):
+ def state_size(self):
+ return (self.view_height + 1) * self.view_width
+
+ def state(self):
i_height, i_width = (
self.view_height - 2 * self.world_margin,
self.view_width - 2 * self.world_margin,
i_height, i_width = (
self.view_height - 2 * self.world_margin,
self.view_width - 2 * self.world_margin,
@@
-418,9
+428,9
@@
class PicroCrafterEngine:
device=v.device,
)
device=v.device,
)
- return v
+ return v.flatten(1), self.life_level_in_100th >= 100
- def seq2tiles(self, t, width=None):
+ def state2str(self, t, width=None):
def tile(n):
n = n.item()
if n in self.id2tile:
def tile(n):
n = n.item()
if n in self.id2tile:
@@
-429,14
+439,14
@@
class PicroCrafterEngine:
return "?"
if t.dim() == 2:
return "?"
if t.dim() == 2:
- return [self.seq2tiles(r, width) for r in t]
+ return [self.state2str(r, width) for r in t]
if width is None:
width = self.view_width
t = t.reshape(-1, width)
if width is None:
width = self.view_width
t = t.reshape(-1, width)
- t = ["".join([tile(n) for n in r]) for r in t]
+ t = "\n".join(["".join([tile(n) for n in r]) for r in t])
return t
return t
@@
-461,7
+471,7
@@
if __name__ == "__main__":
char_conv = lambda x: to_ansi(to_unicode(x))
start_time = time.perf_counter()
char_conv = lambda x: to_ansi(to_unicode(x))
start_time = time.perf_counter()
- engine = PicroCrafterEngine(
+ environment = PicroCrafterEnvironment(
world_height=27,
world_width=27,
nb_walls=35,
world_height=27,
world_width=27,
nb_walls=35,
@@
-471,7
+481,7
@@
if __name__ == "__main__":
device=device,
)
device=device,
)
- engine.reset(nb_agents)
+ environment.reset(nb_agents)
print(f"timing {nb_agents/(time.perf_counter() - start_time)} init per s")
print(f"timing {nb_agents/(time.perf_counter() - start_time)} init per s")
@@
-487,24
+497,29
@@
if __name__ == "__main__":
to_print = ""
os.system("clear")
to_print = ""
os.system("clear")
- l = engine.seq2tiles(engine.worlds.flatten(1), width=engine.world_width)
+ l = environment.state2str(
+ environment.worlds.flatten(1), width=environment.world_width
+ )
to_print += char_conv(fusion_multi_lines(l)) + "\n\n"
to_print += char_conv(fusion_multi_lines(l)) + "\n\n"
- views = engine.views()
- action = torch.randint(engine.nb_actions(), (nb_agents,), device=device)
+ state, alive = environment.state()
+ action = alive * torch.randint(
+ environment.nb_actions(), (nb_agents,), device=device
+ )
- rewards, inventories, life_levels = engine.step(action)
+ rewards, inventories, life_levels = environment.step(action)
if display:
if display:
- l = engine.seq2tiles(views.flatten(1))
+ l = environment.state2str(state)
l = [
l = [
- v + [f"{engine.action2str(a.item())}/{r: 3d}"]
+ v + f"\n{environment.action2str(a.item())}/{r: 3d}"
for (v, a, r) in zip(l, action, rewards)
]
to_print += (
for (v, a, r) in zip(l, action, rewards)
]
to_print += (
- char_conv(fusion_multi_lines(l, width_min=engine.world_width)) + "\n"
+ char_conv(fusion_multi_lines(l, width_min=environment.world_width))
+ + "\n"
)
print(to_print)
)
print(to_print)