X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=picocrafter.py;h=5bd6a48e1d8a0906fb03b183c0b83b76a879fdd0;hb=679c210fdde56edfaf2a1859881139e9d6e81664;hp=ef861bcfbf92440dd92934031e91a0c0d7abc988;hpb=30a9e008b687f097b910ad18dd3699033821852e;p=pytorch.git diff --git a/picocrafter.py b/picocrafter.py index ef861bc..5bd6a48 100755 --- a/picocrafter.py +++ b/picocrafter.py @@ -22,8 +22,9 @@ # it can initialize ~20k environments per second and run ~40k # iterations. # -# The agent "@" moves in a maze-like grid with random walls "#". There -# are five actions: move NESW or do not move. +# The environment is a rectangular area with walls "#" dispatched +# randomly. The agent "@" can perform five actions: move NESW or do +# not move. # # There are monsters "$" moving randomly. The agent gets hit by every # monster present in one of the 4 direct neighborhoods at the end of @@ -40,8 +41,8 @@ # which case the key is removed from the environment and the agent now # carries it, and can move to free spaces or the "A". When it moves to # the "A", it gets a reward, loses the "a", the "A" is removed from -# the environment, but can now move to the "b", etc. Rewards are 1 for -# "A" and "B" and 10 for "C". +# the environment, but the agent can now move to the "b", etc. Rewards +# are 1 for "A" and "B" and 10 for "C". ###################################################################### @@ -52,22 +53,38 @@ from torch.nn.functional import conv2d ###################################################################### -def add_ansi_coloring(s): +def to_ansi(s): if type(s) is list: - return [add_ansi_coloring(x) for x in s] + return [to_ansi(x) for x in s] - for u, c in [("#", 40), ("$", 31), ("@", 32)] + [(x, 36) for x in "aAbBcC"]: + for u, c in [("$", 31), ("@", 32)] + [(x, 36) for x in "aAbBcC"]: s = s.replace(u, f"\u001b[{c}m{u}\u001b[0m") return s +def to_unicode(s): + if type(s) is list: + return [to_unicode(x) for x in s] + + for u, c in [("#", "█"), ("+", "░"), ("|", "│")]: + s = s.replace(u, c) + + return s + + def fusion_multi_lines(l, width_min=0): - l = [x if type(x) is list else [str(x)] for x in l] + l = [x if type(x) is str else str(x) for x in l] + + l = [x.split("\n") for x in l] + + def center(r, w): + k = w - len(r) + return " " * (k // 2) + r + " " * (k - k // 2) def f(o, h): w = max(width_min, max([len(r) for r in o])) - return [" " * w] * (h - len(o)) + [r + " " * (w - len(r)) for r in o] + return [" " * w] * (h - len(o)) + [center(r, w) for r in o] h = max([len(x) for x in l]) l = [f(o, h) for o in l] @@ -75,25 +92,25 @@ def fusion_multi_lines(l, width_min=0): return "\n".join(["|".join([o[k] for o in l]) for k in range(h)]) -class PicroCrafterEngine: +class PicroCrafterEnvironment: def __init__( self, world_height=27, world_width=27, nb_walls=27, - margin=2, + world_margin=2, view_height=5, view_width=5, device=torch.device("cpu"), ): - assert (world_height - 2 * margin) % (view_height - 2 * margin) == 0 - assert (world_width - 2 * margin) % (view_width - 2 * margin) == 0 + assert (world_height - 2 * world_margin) % (view_height - 2 * world_margin) == 0 + assert (world_width - 2 * world_margin) % (view_width - 2 * world_margin) == 0 self.device = device self.world_height = world_height self.world_width = world_width - self.margin = margin + self.world_margin = world_margin self.view_height = view_height self.view_width = view_width self.nb_walls = nb_walls @@ -153,7 +170,11 @@ class PicroCrafterEngine: def reset(self, nb_agents): self.worlds = self.create_worlds( - nb_agents, self.world_height, self.world_width, self.nb_walls, self.margin + nb_agents, + self.world_height, + self.world_width, + self.nb_walls, + self.world_margin, ).to(self.device) self.life_level_in_100th = torch.full( (nb_agents,), self.life_level_max * 100 + 99, device=self.device @@ -194,9 +215,11 @@ class PicroCrafterEngine: return m - def create_worlds(self, nb, height, width, nb_walls, margin=2): - margin -= 1 # The maze adds a wall all around - m = self.create_mazes(nb, height - 2 * margin, width - 2 * margin, nb_walls) + def create_worlds(self, nb, height, width, nb_walls, world_margin=2): + world_margin -= 1 # The maze adds a wall all around + m = self.create_mazes( + nb, height - 2 * world_margin, width - 2 * world_margin, nb_walls + ) q = m.flatten(1) z = "@aAbBcC$$$$$" # What to add to the maze u = torch.rand(q.size(), device=q.device) * (1 - q) @@ -207,12 +230,12 @@ class PicroCrafterEngine: torch.arange(q.size(0), device=q.device)[:, None].expand_as(r), r ] = torch.tensor([self.tile2id[c] for c in z], device=q.device)[None, :] - if margin > 0: + if world_margin > 0: r = m.new_full( - (m.size(0), m.size(1) + margin * 2, m.size(2) + margin * 2), + (m.size(0), m.size(1) + world_margin * 2, m.size(2) + world_margin * 2), self.tile2id["+"], ) - r[:, margin:-margin, margin:-margin] = m + r[:, world_margin:-world_margin, world_margin:-world_margin] = m m = r return m @@ -225,7 +248,7 @@ class PicroCrafterEngine: else: return "?" - def nb_view_tiles(self): + def nb_state_token_values(self): return len(self.tiles) def min_max_reward(self): @@ -256,14 +279,18 @@ class PicroCrafterEngine: nb_hits = self.monster_moves() - alive_before = self.life_level_in_100th > 99 + alive_before = self.life_level_in_100th >= 100 + self.life_level_in_100th[alive_before] = ( self.life_level_in_100th[alive_before] + self.life_level_gain_100th - nb_hits[alive_before] * 100 ).clamp(max=self.life_level_max * 100 + 99) - alive_after = self.life_level_in_100th > 99 + + alive_after = self.life_level_in_100th >= 100 + self.worlds[torch.logical_not(alive_after)] = self.tile2id["#"] + reward = nb_hits * self.reward_per_hit for i in range(q.size(0)): @@ -290,6 +317,7 @@ class PicroCrafterEngine: ) reward[torch.logical_not(alive_before)] = 0 + return reward, inventory, self.life_level_in_100th // 100 def monster_moves(self): @@ -361,14 +389,17 @@ class PicroCrafterEngine: return nb_hits - def views(self): + def state_size(self): + return (self.view_height + 1) * self.view_width + + def state(self): i_height, i_width = ( - self.view_height - 2 * self.margin, - self.view_width - 2 * self.margin, + self.view_height - 2 * self.world_margin, + self.view_width - 2 * self.world_margin, ) a = (self.worlds == self.tile2id["@"]).nonzero() - y = i_height * ((a[:, 1] - self.margin) // i_height) - x = i_width * ((a[:, 2] - self.margin) // i_width) + y = i_height * ((a[:, 1] - self.world_margin) // i_height) + x = i_width * ((a[:, 2] - self.world_margin) // i_width) n = a[:, 0][:, None, None].expand(-1, self.view_height, self.view_width) i = ( torch.arange(self.view_height, device=a.device)[None, :, None] @@ -397,9 +428,9 @@ class PicroCrafterEngine: device=v.device, ) - return v + return v.flatten(1), self.life_level_in_100th >= 100 - def seq2tilepic(self, t, width): + def state2str(self, t, width=None): def tile(n): n = n.item() if n in self.id2tile: @@ -408,49 +439,17 @@ class PicroCrafterEngine: return "?" if t.dim() == 2: - return [self.seq2tilepic(r, width) for r in t] + return [self.state2str(r, width) for r in t] + + if width is None: + width = self.view_width t = t.reshape(-1, width) - t = ["".join([tile(n) for n in r]) for r in t] + t = "\n".join(["".join([tile(n) for n in r]) for r in t]) return t - def print_worlds( - self, src=None, comments=[], width=None, printer=print, ansi_term=False - ): - if src is None: - src = list(self.worlds) - - height = max([x.size(0) if torch.is_tensor(x) else 1 for x in src]) - - def tile(n): - n = n.item() - if n in self.id2tile: - return self.id2tile[n] - else: - return "?" - - for k in range(height): - - def f(x): - if torch.is_tensor(x): - if x.dim() == 0: - x = str(x.item()) - return " " * len(x) if k < height - 1 else x - else: - s = "".join([tile(n) for n in x[k]]) - if ansi_term: - for u, c in [("#", 40), ("$", 31), ("@", 32)] + [ - (x, 36) for x in "aAbBcC" - ]: - s = s.replace(u, f"\u001b[{c}m{u}\u001b[0m") - return s - else: - return " " * len(x) if k < height - 1 else x - - printer("|".join([f(x) for x in src])) - ###################################################################### @@ -459,36 +458,35 @@ if __name__ == "__main__": device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - # nb_agents, nb_iter, display = 10000, 100, False + # char_conv = lambda x: x + char_conv = to_unicode + + # nb_agents, nb_iter, display = 1000, 1000, False # ansi_term = False + nb_agents, nb_iter, display = 4, 10000, True ansi_term = True + if ansi_term: + char_conv = lambda x: to_ansi(to_unicode(x)) + start_time = time.perf_counter() - engine = PicroCrafterEngine( + environment = PicroCrafterEnvironment( world_height=27, world_width=27, nb_walls=35, - # world_height=15, - # world_width=15, - # nb_walls=0, view_height=9, view_width=9, - margin=4, + world_margin=4, device=device, ) - engine.reset(nb_agents) + environment.reset(nb_agents) print(f"timing {nb_agents/(time.perf_counter() - start_time)} init per s") start_time = time.perf_counter() - if ansi_term: - coloring = add_ansi_coloring - else: - coloring = lambda x: x - stop = 0 for k in range(nb_iter): if display: @@ -499,24 +497,29 @@ if __name__ == "__main__": to_print = "" os.system("clear") - l = engine.seq2tilepic(engine.worlds.flatten(1), width=engine.world_width) + l = environment.state2str( + environment.worlds.flatten(1), width=environment.world_width + ) - to_print += coloring(fusion_multi_lines(l)) + "\n\n" + to_print += char_conv(fusion_multi_lines(l)) + "\n\n" - views = engine.views() - action = torch.randint(engine.nb_actions(), (nb_agents,), device=device) + state, alive = environment.state() + action = alive * torch.randint( + environment.nb_actions(), (nb_agents,), device=device + ) - rewards, inventories, life_levels = engine.step(action) + rewards, inventories, life_levels = environment.step(action) if display: - l = engine.seq2tilepic(views.flatten(1), engine.view_width) + l = environment.state2str(state) l = [ - v + [f"{engine.action2str(a.item())}/{r: 3d}"] + v + f"\n{environment.action2str(a.item())}/{r: 3d}" for (v, a, r) in zip(l, action, rewards) ] to_print += ( - coloring(fusion_multi_lines(l, width_min=engine.world_width)) + "\n" + char_conv(fusion_multi_lines(l, width_min=environment.world_width)) + + "\n" ) print(to_print)