From 30a9e008b687f097b910ad18dd3699033821852e Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 2 Nov 2023 17:26:49 +0100 Subject: [PATCH] Update. --- picocrafter.py | 232 ++++++++++++++++++++++++++++++++----------------- 1 file changed, 154 insertions(+), 78 deletions(-) diff --git a/picocrafter.py b/picocrafter.py index 7810b67..ef861bc 100755 --- a/picocrafter.py +++ b/picocrafter.py @@ -52,6 +52,29 @@ from torch.nn.functional import conv2d ###################################################################### +def add_ansi_coloring(s): + if type(s) is list: + return [add_ansi_coloring(x) for x in s] + + for u, c in [("#", 40), ("$", 31), ("@", 32)] + [(x, 36) for x in "aAbBcC"]: + s = s.replace(u, f"\u001b[{c}m{u}\u001b[0m") + + return s + + +def fusion_multi_lines(l, width_min=0): + l = [x if type(x) is list else [str(x)] for x in l] + + def f(o, h): + w = max(width_min, max([len(r) for r in o])) + return [" " * w] * (h - len(o)) + [r + " " * (w - len(r)) for r in o] + + h = max([len(x) for x in l]) + l = [f(o, h) for o in l] + + return "\n".join(["|".join([o[k] for o in l]) for k in range(h)]) + + class PicroCrafterEngine: def __init__( self, @@ -79,27 +102,29 @@ class PicroCrafterEngine: self.reward_per_hit = -1 self.reward_death = -10 - self.tokens = " +#@$aAbBcC." - self.token2id = dict([(t, n) for n, t in enumerate(self.tokens)]) - self.id2token = dict([(n, t) for n, t in enumerate(self.tokens)]) + self.tiles = " +#@$aAbBcC-" + "".join( + [str(n) for n in range(self.life_level_max + 1)] + ) + self.tile2id = dict([(t, n) for n, t in enumerate(self.tiles)]) + self.id2tile = dict([(n, t) for n, t in enumerate(self.tiles)]) self.next_object = dict( [ - (self.token2id[s], self.token2id[t]) + (self.tile2id[s], self.tile2id[t]) for (s, t) in [ ("a", "A"), ("A", "b"), ("b", "B"), ("B", "c"), ("c", "C"), - ("C", "."), + ("C", "-"), ] ] ) self.object_reward = dict( [ - (self.token2id[t], r) + (self.tile2id[t], r) for (t, r) in [ ("a", 0), ("A", 1), @@ -113,7 +138,7 @@ class PicroCrafterEngine: self.accessible_object_to_inventory = dict( [ - (self.token2id[s], self.token2id[t]) + (self.tile2id[s], self.tile2id[t]) for (s, t) in [ ("a", " "), ("A", "a"), @@ -121,7 +146,7 @@ class PicroCrafterEngine: ("B", "b"), ("c", " "), ("C", "c"), - (".", " "), + ("-", " "), ] ] ) @@ -131,10 +156,10 @@ class PicroCrafterEngine: nb_agents, self.world_height, self.world_width, self.nb_walls, self.margin ).to(self.device) self.life_level_in_100th = torch.full( - (nb_agents,), self.life_level_max * 100, device=self.device + (nb_agents,), self.life_level_max * 100 + 99, device=self.device ) self.accessible_object = torch.full( - (nb_agents,), self.token2id["a"], device=self.device + (nb_agents,), self.tile2id["a"], device=self.device ) def create_mazes(self, nb, height, width, nb_walls): @@ -177,15 +202,15 @@ class PicroCrafterEngine: u = torch.rand(q.size(), device=q.device) * (1 - q) r = u.sort(dim=-1, descending=True).indices[:, : len(z)] - q *= self.token2id["#"] + q *= self.tile2id["#"] q[ torch.arange(q.size(0), device=q.device)[:, None].expand_as(r), r - ] = torch.tensor([self.token2id[c] for c in z], device=q.device)[None, :] + ] = torch.tensor([self.tile2id[c] for c in z], device=q.device)[None, :] if margin > 0: r = m.new_full( (m.size(0), m.size(1) + margin * 2, m.size(2) + margin * 2), - self.token2id["+"], + self.tile2id["+"], ) r[:, margin:-margin, margin:-margin] = m m = r @@ -194,8 +219,14 @@ class PicroCrafterEngine: def nb_actions(self): return 5 - def nb_view_tokens(self): - return len(self.tokens) + def action2str(self, n): + if n >= 0 and n < 5: + return "XNESW"[n] + else: + return "?" + + def nb_view_tiles(self): + return len(self.tiles) def min_max_reward(self): return ( @@ -204,20 +235,20 @@ class PicroCrafterEngine: ) def step(self, actions): - a = (self.worlds == self.token2id["@"]).nonzero() - self.worlds[a[:, 0], a[:, 1], a[:, 2]] = self.token2id[" "] + a = (self.worlds == self.tile2id["@"]).nonzero() + self.worlds[a[:, 0], a[:, 1], a[:, 2]] = self.tile2id[" "] s = torch.tensor([[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]], device=self.device) b = a.clone() b[:, 1:] = b[:, 1:] + s[actions[b[:, 0]]] # position is empty - o = (self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.token2id[" "]).long() + o = (self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.tile2id[" "]).long() # or it is the next accessible object q = ( self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.accessible_object[b[:, 0]] ).long() o = (o + q).clamp(max=1)[:, None] b = (1 - o) * a + o * b - self.worlds[b[:, 0], b[:, 1], b[:, 2]] = self.token2id["@"] + self.worlds[b[:, 0], b[:, 1], b[:, 2]] = self.tile2id["@"] qq = q q = qq.new_zeros((self.worlds.size(0),) + qq.size()[1:]) @@ -225,14 +256,14 @@ class PicroCrafterEngine: nb_hits = self.monster_moves() - alive_before = self.life_level_in_100th > 0 + alive_before = self.life_level_in_100th > 99 self.life_level_in_100th[alive_before] = ( self.life_level_in_100th[alive_before] + self.life_level_gain_100th - nb_hits[alive_before] * 100 - ).clamp(max=self.life_level_max * 100) - alive_after = self.life_level_in_100th > 0 - self.worlds[torch.logical_not(alive_after)] = self.token2id["#"] + ).clamp(max=self.life_level_max * 100 + 99) + alive_after = self.life_level_in_100th > 99 + self.worlds[torch.logical_not(alive_after)] = self.tile2id["#"] reward = nb_hits * self.reward_per_hit for i in range(q.size(0)): @@ -243,7 +274,8 @@ class PicroCrafterEngine: ] reward = ( - reward + alive_before.long() * (1 - alive_after.long()) * self.reward_death + alive_after.long() * reward + + alive_before.long() * (1 - alive_after.long()) * self.reward_death ) inventory = torch.tensor( [ @@ -254,7 +286,7 @@ class PicroCrafterEngine: self.life_level_in_100th = ( self.life_level_in_100th - * (self.accessible_object != self.token2id["."]).long() + * (self.accessible_object != self.tile2id["-"]).long() ) reward[torch.logical_not(alive_before)] = 0 @@ -262,7 +294,7 @@ class PicroCrafterEngine: def monster_moves(self): # Current positions of the monsters - m = (self.worlds == self.token2id["$"]).long().flatten(1) + m = (self.worlds == self.tile2id["$"]).long().flatten(1) # Total number of monsters n = m.sum(-1).max() @@ -303,25 +335,25 @@ class PicroCrafterEngine: for n in range(p.size(1)): u = o[:, n].sort(dim=-1, descending=True).indices[:, :1] - q = p[:, n] * (self.worlds.flatten(1) == self.token2id[" "]) + o[:, n] + q = p[:, n] * (self.worlds.flatten(1) == self.tile2id[" "]) + o[:, n] r = ( (q * torch.rand(q.size(), device=q.device)) .sort(dim=-1, descending=True) .indices[:, :1] ) - self.worlds.flatten(1)[i, u] = self.token2id[" "] - self.worlds.flatten(1)[i, r] = self.token2id["$"] + self.worlds.flatten(1)[i, u] = self.tile2id[" "] + self.worlds.flatten(1)[i, r] = self.tile2id["$"] nb_hits = ( ( conv2d( - (self.worlds == self.token2id["$"]).float()[:, None], + (self.worlds == self.tile2id["$"]).float()[:, None], move_kernel, padding=1, ) .long() .squeeze(1) - * (self.worlds == self.token2id["@"]).long() + * (self.worlds == self.tile2id["@"]).long() ) .flatten(1) .sum(-1) @@ -334,7 +366,7 @@ class PicroCrafterEngine: self.view_height - 2 * self.margin, self.view_width - 2 * self.margin, ) - a = (self.worlds == self.token2id["@"]).nonzero() + a = (self.worlds == self.tile2id["@"]).nonzero() y = i_height * ((a[:, 1] - self.margin) // i_height) x = i_width * ((a[:, 2] - self.margin) // i_width) n = a[:, 0][:, None, None].expand(-1, self.view_height, self.view_width) @@ -347,58 +379,89 @@ class PicroCrafterEngine: + x[:, None, None] ).expand_as(n) v = self.worlds.new_full( - (self.worlds.size(0), self.view_height, self.view_width), self.token2id["#"] + (self.worlds.size(0), self.view_height + 1, self.view_width), + self.tile2id["#"], ) - v[a[:, 0]] = self.worlds[n, i, j] + v[a[:, 0], : self.view_height] = self.worlds[n, i, j] + + v[:, self.view_height] = self.tile2id["-"] + v[:, self.view_height, 0] = self.tile2id["0"] + ( + self.life_level_in_100th // 100 + ).clamp(min=0, max=self.life_level_max) + v[:, self.view_height, 1] = torch.tensor( + [ + self.accessible_object_to_inventory[o.item()] + for o in self.accessible_object + ], + device=v.device, + ) return v + def seq2tilepic(self, t, width): + def tile(n): + n = n.item() + if n in self.id2tile: + return self.id2tile[n] + else: + return "?" + + if t.dim() == 2: + return [self.seq2tilepic(r, width) for r in t] + + t = t.reshape(-1, width) + + t = ["".join([tile(n) for n in r]) for r in t] + + return t + def print_worlds( self, src=None, comments=[], width=None, printer=print, ansi_term=False ): if src is None: - src = self.worlds + src = list(self.worlds) - if width is None: - width = src.size(2) + height = max([x.size(0) if torch.is_tensor(x) else 1 for x in src]) - def token(n): + def tile(n): n = n.item() - if n in self.id2token: - return self.id2token[n] + if n in self.id2tile: + return self.id2tile[n] else: return "?" - for k in range(src.size(1)): - s = ["".join([token(n) for n in m[k]]) for m in src] - s = [r + " " * (width - len(r)) for r in s] - if ansi_term: - - def colorize(x): - for u, c in [("#", 40), ("$", 31), ("@", 32)] + [ - (x, 36) for x in "aAbBcC" - ]: - x = x.replace(u, f"\u001b[{c}m{u}\u001b[0m") - return x + for k in range(height): - s = [colorize(x) for x in s] - printer(" | ".join(s)) + def f(x): + if torch.is_tensor(x): + if x.dim() == 0: + x = str(x.item()) + return " " * len(x) if k < height - 1 else x + else: + s = "".join([tile(n) for n in x[k]]) + if ansi_term: + for u, c in [("#", 40), ("$", 31), ("@", 32)] + [ + (x, 36) for x in "aAbBcC" + ]: + s = s.replace(u, f"\u001b[{c}m{u}\u001b[0m") + return s + else: + return " " * len(x) if k < height - 1 else x - s = [c + " " * (width - len(c)) for c in comments] - printer(" | ".join(s)) + printer("|".join([f(x) for x in src])) ###################################################################### if __name__ == "__main__": - import os, time + import os, time, sys device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # nb_agents, nb_iter, display = 10000, 100, False # ansi_term = False - # nb_agents, nb_iter, display = 1000, 1000, False - nb_agents, nb_iter, display = 3, 10000, True + nb_agents, nb_iter, display = 4, 10000, True ansi_term = True start_time = time.perf_counter() @@ -421,35 +484,48 @@ if __name__ == "__main__": start_time = time.perf_counter() + if ansi_term: + coloring = add_ansi_coloring + else: + coloring = lambda x: x + stop = 0 for k in range(nb_iter): + if display: + if ansi_term: + to_print = "\u001bc" + # print("\u001b[2J") + else: + to_print = "" + os.system("clear") + + l = engine.seq2tilepic(engine.worlds.flatten(1), width=engine.world_width) + + to_print += coloring(fusion_multi_lines(l)) + "\n\n" + + views = engine.views() action = torch.randint(engine.nb_actions(), (nb_agents,), device=device) - rewards, inventories, life_levels = engine.step( - torch.randint(engine.nb_actions(), (nb_agents,), device=device) - ) + + rewards, inventories, life_levels = engine.step(action) if display: - os.system("clear") - engine.print_worlds( - ansi_term=ansi_term, - ) - print() - engine.print_worlds( - src=engine.views(), - comments=[ - f"L{p}I{engine.id2token[s.item()]}R{r}" - for p, s, r in zip(life_levels, inventories, rewards) - ], - width=engine.world_width, - ansi_term=ansi_term, + l = engine.seq2tilepic(views.flatten(1), engine.view_width) + l = [ + v + [f"{engine.action2str(a.item())}/{r: 3d}"] + for (v, a, r) in zip(l, action, rewards) + ] + + to_print += ( + coloring(fusion_multi_lines(l, width_min=engine.world_width)) + "\n" ) + + print(to_print) + sys.stdout.flush() time.sleep(0.25) if (life_levels > 0).long().sum() == 0: stop += 1 - if stop == 2: + if stop == 10: break - print( - f"timing {(nb_agents*nb_iter)/(time.perf_counter() - start_time)} iteration per s" - ) + print(f"timing {(nb_agents*k)/(time.perf_counter() - start_time)} iteration per s") -- 2.39.5