From aaa1fc7aa021c820ca85c5336726d483db57074a Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 4 Nov 2023 12:12:30 +0100 Subject: [PATCH] Update. --- picocrafter.py | 87 +++++++++++++++++++------------------------------- 1 file changed, 33 insertions(+), 54 deletions(-) diff --git a/picocrafter.py b/picocrafter.py index ef861bc..e303554 100755 --- a/picocrafter.py +++ b/picocrafter.py @@ -22,8 +22,9 @@ # it can initialize ~20k environments per second and run ~40k # iterations. # -# The agent "@" moves in a maze-like grid with random walls "#". There -# are five actions: move NESW or do not move. +# The environment is a rectangular area with walls "#" dispatched +# randomly. The agent "@" can perform five actions: move NESW or do +# not move. # # There are monsters "$" moving randomly. The agent gets hit by every # monster present in one of the 4 direct neighborhoods at the end of @@ -40,8 +41,8 @@ # which case the key is removed from the environment and the agent now # carries it, and can move to free spaces or the "A". When it moves to # the "A", it gets a reward, loses the "a", the "A" is removed from -# the environment, but can now move to the "b", etc. Rewards are 1 for -# "A" and "B" and 10 for "C". +# the environment, but the agent can now move to the "b", etc. Rewards +# are 1 for "A" and "B" and 10 for "C". ###################################################################### @@ -52,22 +53,36 @@ from torch.nn.functional import conv2d ###################################################################### -def add_ansi_coloring(s): +def to_ansi(s): if type(s) is list: - return [add_ansi_coloring(x) for x in s] + return [to_ansi(x) for x in s] - for u, c in [("#", 40), ("$", 31), ("@", 32)] + [(x, 36) for x in "aAbBcC"]: + for u, c in [("$", 31), ("@", 32)] + [(x, 36) for x in "aAbBcC"]: s = s.replace(u, f"\u001b[{c}m{u}\u001b[0m") return s +def to_unicode(s): + if type(s) is list: + return [to_unicode(x) for x in s] + + for u, c in [("#", "█"), ("+", "░"), ("|", "│")]: + s = s.replace(u, c) + + return s + + def fusion_multi_lines(l, width_min=0): l = [x if type(x) is list else [str(x)] for x in l] + def center(r, w): + k = w - len(r) + return " " * (k // 2) + r + " " * (k - k // 2) + def f(o, h): w = max(width_min, max([len(r) for r in o])) - return [" " * w] * (h - len(o)) + [r + " " * (w - len(r)) for r in o] + return [" " * w] * (h - len(o)) + [center(r, w) for r in o] h = max([len(x) for x in l]) l = [f(o, h) for o in l] @@ -416,41 +431,6 @@ class PicroCrafterEngine: return t - def print_worlds( - self, src=None, comments=[], width=None, printer=print, ansi_term=False - ): - if src is None: - src = list(self.worlds) - - height = max([x.size(0) if torch.is_tensor(x) else 1 for x in src]) - - def tile(n): - n = n.item() - if n in self.id2tile: - return self.id2tile[n] - else: - return "?" - - for k in range(height): - - def f(x): - if torch.is_tensor(x): - if x.dim() == 0: - x = str(x.item()) - return " " * len(x) if k < height - 1 else x - else: - s = "".join([tile(n) for n in x[k]]) - if ansi_term: - for u, c in [("#", 40), ("$", 31), ("@", 32)] + [ - (x, 36) for x in "aAbBcC" - ]: - s = s.replace(u, f"\u001b[{c}m{u}\u001b[0m") - return s - else: - return " " * len(x) if k < height - 1 else x - - printer("|".join([f(x) for x in src])) - ###################################################################### @@ -459,19 +439,23 @@ if __name__ == "__main__": device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - # nb_agents, nb_iter, display = 10000, 100, False + # char_conv = lambda x: x + char_conv = to_unicode + + # nb_agents, nb_iter, display = 1000, 1000, False # ansi_term = False + nb_agents, nb_iter, display = 4, 10000, True ansi_term = True + if ansi_term: + char_conv = lambda x: to_ansi(to_unicode(x)) + start_time = time.perf_counter() engine = PicroCrafterEngine( world_height=27, world_width=27, nb_walls=35, - # world_height=15, - # world_width=15, - # nb_walls=0, view_height=9, view_width=9, margin=4, @@ -484,11 +468,6 @@ if __name__ == "__main__": start_time = time.perf_counter() - if ansi_term: - coloring = add_ansi_coloring - else: - coloring = lambda x: x - stop = 0 for k in range(nb_iter): if display: @@ -501,7 +480,7 @@ if __name__ == "__main__": l = engine.seq2tilepic(engine.worlds.flatten(1), width=engine.world_width) - to_print += coloring(fusion_multi_lines(l)) + "\n\n" + to_print += char_conv(fusion_multi_lines(l)) + "\n\n" views = engine.views() action = torch.randint(engine.nb_actions(), (nb_agents,), device=device) @@ -516,7 +495,7 @@ if __name__ == "__main__": ] to_print += ( - coloring(fusion_multi_lines(l, width_min=engine.world_width)) + "\n" + char_conv(fusion_multi_lines(l, width_min=engine.world_width)) + "\n" ) print(to_print) -- 2.39.5