######################################################################
+def add_ansi_coloring(s):
+ if type(s) is list:
+ return [add_ansi_coloring(x) for x in s]
+
+ for u, c in [("#", 40), ("$", 31), ("@", 32)] + [(x, 36) for x in "aAbBcC"]:
+ s = s.replace(u, f"\u001b[{c}m{u}\u001b[0m")
+
+ return s
+
+
+def fusion_multi_lines(l, width_min=0):
+ l = [x if type(x) is list else [str(x)] for x in l]
+
+ def f(o, h):
+ w = max(width_min, max([len(r) for r in o]))
+ return [" " * w] * (h - len(o)) + [r + " " * (w - len(r)) for r in o]
+
+ h = max([len(x) for x in l])
+ l = [f(o, h) for o in l]
+
+ return "\n".join(["|".join([o[k] for o in l]) for k in range(h)])
+
+
class PicroCrafterEngine:
def __init__(
self,
self.reward_per_hit = -1
self.reward_death = -10
- self.tokens = " +#@$aAbBcC."
- self.token2id = dict([(t, n) for n, t in enumerate(self.tokens)])
- self.id2token = dict([(n, t) for n, t in enumerate(self.tokens)])
+ self.tiles = " +#@$aAbBcC-" + "".join(
+ [str(n) for n in range(self.life_level_max + 1)]
+ )
+ self.tile2id = dict([(t, n) for n, t in enumerate(self.tiles)])
+ self.id2tile = dict([(n, t) for n, t in enumerate(self.tiles)])
self.next_object = dict(
[
- (self.token2id[s], self.token2id[t])
+ (self.tile2id[s], self.tile2id[t])
for (s, t) in [
("a", "A"),
("A", "b"),
("b", "B"),
("B", "c"),
("c", "C"),
- ("C", "."),
+ ("C", "-"),
]
]
)
self.object_reward = dict(
[
- (self.token2id[t], r)
+ (self.tile2id[t], r)
for (t, r) in [
("a", 0),
("A", 1),
self.accessible_object_to_inventory = dict(
[
- (self.token2id[s], self.token2id[t])
+ (self.tile2id[s], self.tile2id[t])
for (s, t) in [
("a", " "),
("A", "a"),
("B", "b"),
("c", " "),
("C", "c"),
- (".", " "),
+ ("-", " "),
]
]
)
nb_agents, self.world_height, self.world_width, self.nb_walls, self.margin
).to(self.device)
self.life_level_in_100th = torch.full(
- (nb_agents,), self.life_level_max * 100, device=self.device
+ (nb_agents,), self.life_level_max * 100 + 99, device=self.device
)
self.accessible_object = torch.full(
- (nb_agents,), self.token2id["a"], device=self.device
+ (nb_agents,), self.tile2id["a"], device=self.device
)
def create_mazes(self, nb, height, width, nb_walls):
u = torch.rand(q.size(), device=q.device) * (1 - q)
r = u.sort(dim=-1, descending=True).indices[:, : len(z)]
- q *= self.token2id["#"]
+ q *= self.tile2id["#"]
q[
torch.arange(q.size(0), device=q.device)[:, None].expand_as(r), r
- ] = torch.tensor([self.token2id[c] for c in z], device=q.device)[None, :]
+ ] = torch.tensor([self.tile2id[c] for c in z], device=q.device)[None, :]
if margin > 0:
r = m.new_full(
(m.size(0), m.size(1) + margin * 2, m.size(2) + margin * 2),
- self.token2id["+"],
+ self.tile2id["+"],
)
r[:, margin:-margin, margin:-margin] = m
m = r
def nb_actions(self):
return 5
- def nb_view_tokens(self):
- return len(self.tokens)
+ def action2str(self, n):
+ if n >= 0 and n < 5:
+ return "XNESW"[n]
+ else:
+ return "?"
+
+ def nb_view_tiles(self):
+ return len(self.tiles)
def min_max_reward(self):
return (
)
def step(self, actions):
- a = (self.worlds == self.token2id["@"]).nonzero()
- self.worlds[a[:, 0], a[:, 1], a[:, 2]] = self.token2id[" "]
+ a = (self.worlds == self.tile2id["@"]).nonzero()
+ self.worlds[a[:, 0], a[:, 1], a[:, 2]] = self.tile2id[" "]
s = torch.tensor([[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]], device=self.device)
b = a.clone()
b[:, 1:] = b[:, 1:] + s[actions[b[:, 0]]]
# position is empty
- o = (self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.token2id[" "]).long()
+ o = (self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.tile2id[" "]).long()
# or it is the next accessible object
q = (
self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.accessible_object[b[:, 0]]
).long()
o = (o + q).clamp(max=1)[:, None]
b = (1 - o) * a + o * b
- self.worlds[b[:, 0], b[:, 1], b[:, 2]] = self.token2id["@"]
+ self.worlds[b[:, 0], b[:, 1], b[:, 2]] = self.tile2id["@"]
qq = q
q = qq.new_zeros((self.worlds.size(0),) + qq.size()[1:])
nb_hits = self.monster_moves()
- alive_before = self.life_level_in_100th > 0
+ alive_before = self.life_level_in_100th > 99
self.life_level_in_100th[alive_before] = (
self.life_level_in_100th[alive_before]
+ self.life_level_gain_100th
- nb_hits[alive_before] * 100
- ).clamp(max=self.life_level_max * 100)
- alive_after = self.life_level_in_100th > 0
- self.worlds[torch.logical_not(alive_after)] = self.token2id["#"]
+ ).clamp(max=self.life_level_max * 100 + 99)
+ alive_after = self.life_level_in_100th > 99
+ self.worlds[torch.logical_not(alive_after)] = self.tile2id["#"]
reward = nb_hits * self.reward_per_hit
for i in range(q.size(0)):
]
reward = (
- reward + alive_before.long() * (1 - alive_after.long()) * self.reward_death
+ alive_after.long() * reward
+ + alive_before.long() * (1 - alive_after.long()) * self.reward_death
)
inventory = torch.tensor(
[
self.life_level_in_100th = (
self.life_level_in_100th
- * (self.accessible_object != self.token2id["."]).long()
+ * (self.accessible_object != self.tile2id["-"]).long()
)
reward[torch.logical_not(alive_before)] = 0
def monster_moves(self):
# Current positions of the monsters
- m = (self.worlds == self.token2id["$"]).long().flatten(1)
+ m = (self.worlds == self.tile2id["$"]).long().flatten(1)
# Total number of monsters
n = m.sum(-1).max()
for n in range(p.size(1)):
u = o[:, n].sort(dim=-1, descending=True).indices[:, :1]
- q = p[:, n] * (self.worlds.flatten(1) == self.token2id[" "]) + o[:, n]
+ q = p[:, n] * (self.worlds.flatten(1) == self.tile2id[" "]) + o[:, n]
r = (
(q * torch.rand(q.size(), device=q.device))
.sort(dim=-1, descending=True)
.indices[:, :1]
)
- self.worlds.flatten(1)[i, u] = self.token2id[" "]
- self.worlds.flatten(1)[i, r] = self.token2id["$"]
+ self.worlds.flatten(1)[i, u] = self.tile2id[" "]
+ self.worlds.flatten(1)[i, r] = self.tile2id["$"]
nb_hits = (
(
conv2d(
- (self.worlds == self.token2id["$"]).float()[:, None],
+ (self.worlds == self.tile2id["$"]).float()[:, None],
move_kernel,
padding=1,
)
.long()
.squeeze(1)
- * (self.worlds == self.token2id["@"]).long()
+ * (self.worlds == self.tile2id["@"]).long()
)
.flatten(1)
.sum(-1)
self.view_height - 2 * self.margin,
self.view_width - 2 * self.margin,
)
- a = (self.worlds == self.token2id["@"]).nonzero()
+ a = (self.worlds == self.tile2id["@"]).nonzero()
y = i_height * ((a[:, 1] - self.margin) // i_height)
x = i_width * ((a[:, 2] - self.margin) // i_width)
n = a[:, 0][:, None, None].expand(-1, self.view_height, self.view_width)
+ x[:, None, None]
).expand_as(n)
v = self.worlds.new_full(
- (self.worlds.size(0), self.view_height, self.view_width), self.token2id["#"]
+ (self.worlds.size(0), self.view_height + 1, self.view_width),
+ self.tile2id["#"],
)
- v[a[:, 0]] = self.worlds[n, i, j]
+ v[a[:, 0], : self.view_height] = self.worlds[n, i, j]
+
+ v[:, self.view_height] = self.tile2id["-"]
+ v[:, self.view_height, 0] = self.tile2id["0"] + (
+ self.life_level_in_100th // 100
+ ).clamp(min=0, max=self.life_level_max)
+ v[:, self.view_height, 1] = torch.tensor(
+ [
+ self.accessible_object_to_inventory[o.item()]
+ for o in self.accessible_object
+ ],
+ device=v.device,
+ )
return v
+ def seq2tilepic(self, t, width):
+ def tile(n):
+ n = n.item()
+ if n in self.id2tile:
+ return self.id2tile[n]
+ else:
+ return "?"
+
+ if t.dim() == 2:
+ return [self.seq2tilepic(r, width) for r in t]
+
+ t = t.reshape(-1, width)
+
+ t = ["".join([tile(n) for n in r]) for r in t]
+
+ return t
+
def print_worlds(
self, src=None, comments=[], width=None, printer=print, ansi_term=False
):
if src is None:
- src = self.worlds
+ src = list(self.worlds)
- if width is None:
- width = src.size(2)
+ height = max([x.size(0) if torch.is_tensor(x) else 1 for x in src])
- def token(n):
+ def tile(n):
n = n.item()
- if n in self.id2token:
- return self.id2token[n]
+ if n in self.id2tile:
+ return self.id2tile[n]
else:
return "?"
- for k in range(src.size(1)):
- s = ["".join([token(n) for n in m[k]]) for m in src]
- s = [r + " " * (width - len(r)) for r in s]
- if ansi_term:
-
- def colorize(x):
- for u, c in [("#", 40), ("$", 31), ("@", 32)] + [
- (x, 36) for x in "aAbBcC"
- ]:
- x = x.replace(u, f"\u001b[{c}m{u}\u001b[0m")
- return x
+ for k in range(height):
- s = [colorize(x) for x in s]
- printer(" | ".join(s))
+ def f(x):
+ if torch.is_tensor(x):
+ if x.dim() == 0:
+ x = str(x.item())
+ return " " * len(x) if k < height - 1 else x
+ else:
+ s = "".join([tile(n) for n in x[k]])
+ if ansi_term:
+ for u, c in [("#", 40), ("$", 31), ("@", 32)] + [
+ (x, 36) for x in "aAbBcC"
+ ]:
+ s = s.replace(u, f"\u001b[{c}m{u}\u001b[0m")
+ return s
+ else:
+ return " " * len(x) if k < height - 1 else x
- s = [c + " " * (width - len(c)) for c in comments]
- printer(" | ".join(s))
+ printer("|".join([f(x) for x in src]))
######################################################################
if __name__ == "__main__":
- import os, time
+ import os, time, sys
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ # nb_agents, nb_iter, display = 10000, 100, False
# ansi_term = False
- # nb_agents, nb_iter, display = 1000, 1000, False
- nb_agents, nb_iter, display = 3, 10000, True
+ nb_agents, nb_iter, display = 4, 10000, True
ansi_term = True
start_time = time.perf_counter()
start_time = time.perf_counter()
+ if ansi_term:
+ coloring = add_ansi_coloring
+ else:
+ coloring = lambda x: x
+
stop = 0
for k in range(nb_iter):
+ if display:
+ if ansi_term:
+ to_print = "\u001bc"
+ # print("\u001b[2J")
+ else:
+ to_print = ""
+ os.system("clear")
+
+ l = engine.seq2tilepic(engine.worlds.flatten(1), width=engine.world_width)
+
+ to_print += coloring(fusion_multi_lines(l)) + "\n\n"
+
+ views = engine.views()
action = torch.randint(engine.nb_actions(), (nb_agents,), device=device)
- rewards, inventories, life_levels = engine.step(
- torch.randint(engine.nb_actions(), (nb_agents,), device=device)
- )
+
+ rewards, inventories, life_levels = engine.step(action)
if display:
- os.system("clear")
- engine.print_worlds(
- ansi_term=ansi_term,
- )
- print()
- engine.print_worlds(
- src=engine.views(),
- comments=[
- f"L{p}I{engine.id2token[s.item()]}R{r}"
- for p, s, r in zip(life_levels, inventories, rewards)
- ],
- width=engine.world_width,
- ansi_term=ansi_term,
+ l = engine.seq2tilepic(views.flatten(1), engine.view_width)
+ l = [
+ v + [f"{engine.action2str(a.item())}/{r: 3d}"]
+ for (v, a, r) in zip(l, action, rewards)
+ ]
+
+ to_print += (
+ coloring(fusion_multi_lines(l, width_min=engine.world_width)) + "\n"
)
+
+ print(to_print)
+ sys.stdout.flush()
time.sleep(0.25)
if (life_levels > 0).long().sum() == 0:
stop += 1
- if stop == 2:
+ if stop == 10:
break
- print(
- f"timing {(nb_agents*nb_iter)/(time.perf_counter() - start_time)} iteration per s"
- )
+ print(f"timing {(nb_agents*k)/(time.perf_counter() - start_time)} iteration per s")