def fusion_multi_lines(l, width_min=0):
- l = [x if type(x) is list else [str(x)] for x in l]
+ l = [x if type(x) is str else str(x) for x in l]
+
+ l = [x.split("\n") for x in l]
def center(r, w):
k = w - len(r)
return "\n".join(["|".join([o[k] for o in l]) for k in range(h)])
-class PicroCrafterEngine:
+class PicroCrafterEnvironment:
def __init__(
self,
world_height=27,
else:
return "?"
- def nb_view_tiles(self):
+ def nb_state_token_values(self):
return len(self.tiles)
def min_max_reward(self):
nb_hits = self.monster_moves()
- alive_before = self.life_level_in_100th > 99
+ alive_before = self.life_level_in_100th >= 100
+
self.life_level_in_100th[alive_before] = (
self.life_level_in_100th[alive_before]
+ self.life_level_gain_100th
- nb_hits[alive_before] * 100
).clamp(max=self.life_level_max * 100 + 99)
- alive_after = self.life_level_in_100th > 99
+
+ alive_after = self.life_level_in_100th >= 100
+
self.worlds[torch.logical_not(alive_after)] = self.tile2id["#"]
+
reward = nb_hits * self.reward_per_hit
for i in range(q.size(0)):
)
reward[torch.logical_not(alive_before)] = 0
+
return reward, inventory, self.life_level_in_100th // 100
def monster_moves(self):
return nb_hits
- def views(self):
+ def state_size(self):
+ return (self.view_height + 1) * self.view_width
+
+ def state(self):
i_height, i_width = (
self.view_height - 2 * self.world_margin,
self.view_width - 2 * self.world_margin,
device=v.device,
)
- return v
+ return v.flatten(1), self.life_level_in_100th >= 100
- def seq2tiles(self, t, width=None):
+ def state2str(self, t, width=None):
def tile(n):
n = n.item()
if n in self.id2tile:
return "?"
if t.dim() == 2:
- return [self.seq2tiles(r, width) for r in t]
+ return [self.state2str(r, width) for r in t]
if width is None:
width = self.view_width
t = t.reshape(-1, width)
- t = ["".join([tile(n) for n in r]) for r in t]
+ t = "\n".join(["".join([tile(n) for n in r]) for r in t])
return t
char_conv = lambda x: to_ansi(to_unicode(x))
start_time = time.perf_counter()
- engine = PicroCrafterEngine(
+ environment = PicroCrafterEnvironment(
world_height=27,
world_width=27,
nb_walls=35,
device=device,
)
- engine.reset(nb_agents)
+ environment.reset(nb_agents)
print(f"timing {nb_agents/(time.perf_counter() - start_time)} init per s")
to_print = ""
os.system("clear")
- l = engine.seq2tiles(engine.worlds.flatten(1), width=engine.world_width)
+ l = environment.state2str(
+ environment.worlds.flatten(1), width=environment.world_width
+ )
to_print += char_conv(fusion_multi_lines(l)) + "\n\n"
- views = engine.views()
- action = torch.randint(engine.nb_actions(), (nb_agents,), device=device)
+ state, alive = environment.state()
+ action = alive * torch.randint(
+ environment.nb_actions(), (nb_agents,), device=device
+ )
- rewards, inventories, life_levels = engine.step(action)
+ rewards, inventories, life_levels = environment.step(action)
if display:
- l = engine.seq2tiles(views.flatten(1))
+ l = environment.state2str(state)
l = [
- v + [f"{engine.action2str(a.item())}/{r: 3d}"]
+ v + f"\n{environment.action2str(a.item())}/{r: 3d}"
for (v, a, r) in zip(l, action, rewards)
]
to_print += (
- char_conv(fusion_multi_lines(l, width_min=engine.world_width)) + "\n"
+ char_conv(fusion_multi_lines(l, width_min=environment.world_width))
+ + "\n"
)
print(to_print)