Update.
[pytorch.git] / picocrafter.py
index ef861bc..23d93b2 100755 (executable)
@@ -22,8 +22,9 @@
 # it can initialize ~20k environments per second and run ~40k
 # iterations.
 #
-# The agent "@" moves in a maze-like grid with random walls "#". There
-# are five actions: move NESW or do not move.
+# The environment is a rectangular area with walls "#" dispatched
+# randomly. The agent "@" can perform five actions: move "NESW" or be
+# immobile "I".
 #
 # There are monsters "$" moving randomly. The agent gets hit by every
 # monster present in one of the 4 direct neighborhoods at the end of
 # "B", "C"). The keys and vault can only be used in sequence:
 # initially the agent can move only to free spaces, or to the "a", in
 # which case the key is removed from the environment and the agent now
-# carries it, and can move to free spaces or the "A". When it moves to
-# the "A", it gets a reward, loses the "a", the "A" is removed from
-# the environment, but can now move to the "b", etc. Rewards are 1 for
-# "A" and "B" and 10 for "C".
+# carries it, it appears in the inventory at the bottom of the frame,
+# and the agent can now move to free spaces or the "A". When it moves
+# to the "A", it gets a reward, loses the "a", the "A" is removed from
+# the environment, but the agent can now move to the "b", etc. Rewards
+# are 1 for "A" and "B" and 10 for "C".
 
 ######################################################################
 
@@ -52,22 +54,38 @@ from torch.nn.functional import conv2d
 ######################################################################
 
 
-def add_ansi_coloring(s):
+def to_ansi(s):
     if type(s) is list:
-        return [add_ansi_coloring(x) for x in s]
+        return [to_ansi(x) for x in s]
 
-    for u, c in [("#", 40), ("$", 31), ("@", 32)] + [(x, 36) for x in "aAbBcC"]:
+    for u, c in [("$", 31), ("@", 32)] + [(x, 36) for x in "aAbBcC"]:
         s = s.replace(u, f"\u001b[{c}m{u}\u001b[0m")
 
     return s
 
 
+def to_unicode(s):
+    if type(s) is list:
+        return [to_unicode(x) for x in s]
+
+    for u, c in [("#", "█"), ("+", "░"), ("|", "│")]:
+        s = s.replace(u, c)
+
+    return s
+
+
 def fusion_multi_lines(l, width_min=0):
-    l = [x if type(x) is list else [str(x)] for x in l]
+    l = [x if type(x) is str else str(x) for x in l]
+
+    l = [x.split("\n") for x in l]
+
+    def center(r, w):
+        k = w - len(r)
+        return " " * (k // 2) + r + " " * (k - k // 2)
 
     def f(o, h):
         w = max(width_min, max([len(r) for r in o]))
-        return [" " * w] * (h - len(o)) + [r + " " * (w - len(r)) for r in o]
+        return [" " * w] * (h - len(o)) + [center(r, w) for r in o]
 
     h = max([len(x) for x in l])
     l = [f(o, h) for o in l]
@@ -75,25 +93,25 @@ def fusion_multi_lines(l, width_min=0):
     return "\n".join(["|".join([o[k] for o in l]) for k in range(h)])
 
 
-class PicroCrafterEngine:
+class PicroCrafterEnvironment:
     def __init__(
         self,
         world_height=27,
         world_width=27,
         nb_walls=27,
-        margin=2,
+        world_margin=2,
         view_height=5,
         view_width=5,
         device=torch.device("cpu"),
     ):
-        assert (world_height - 2 * margin) % (view_height - 2 * margin) == 0
-        assert (world_width - 2 * margin) % (view_width - 2 * margin) == 0
+        assert (world_height - 2 * world_margin) % (view_height - 2 * world_margin) == 0
+        assert (world_width - 2 * world_margin) % (view_width - 2 * world_margin) == 0
 
         self.device = device
 
         self.world_height = world_height
         self.world_width = world_width
-        self.margin = margin
+        self.world_margin = world_margin
         self.view_height = view_height
         self.view_width = view_width
         self.nb_walls = nb_walls
@@ -153,7 +171,11 @@ class PicroCrafterEngine:
 
     def reset(self, nb_agents):
         self.worlds = self.create_worlds(
-            nb_agents, self.world_height, self.world_width, self.nb_walls, self.margin
+            nb_agents,
+            self.world_height,
+            self.world_width,
+            self.nb_walls,
+            self.world_margin,
         ).to(self.device)
         self.life_level_in_100th = torch.full(
             (nb_agents,), self.life_level_max * 100 + 99, device=self.device
@@ -194,9 +216,11 @@ class PicroCrafterEngine:
 
         return m
 
-    def create_worlds(self, nb, height, width, nb_walls, margin=2):
-        margin -= 1  # The maze adds a wall all around
-        m = self.create_mazes(nb, height - 2 * margin, width - 2 * margin, nb_walls)
+    def create_worlds(self, nb, height, width, nb_walls, world_margin=2):
+        world_margin -= 1  # The maze adds a wall all around
+        m = self.create_mazes(
+            nb, height - 2 * world_margin, width - 2 * world_margin, nb_walls
+        )
         q = m.flatten(1)
         z = "@aAbBcC$$$$$"  # What to add to the maze
         u = torch.rand(q.size(), device=q.device) * (1 - q)
@@ -207,12 +231,12 @@ class PicroCrafterEngine:
             torch.arange(q.size(0), device=q.device)[:, None].expand_as(r), r
         ] = torch.tensor([self.tile2id[c] for c in z], device=q.device)[None, :]
 
-        if margin > 0:
+        if world_margin > 0:
             r = m.new_full(
-                (m.size(0), m.size(1) + margin * 2, m.size(2) + margin * 2),
+                (m.size(0), m.size(1) + world_margin * 2, m.size(2) + world_margin * 2),
                 self.tile2id["+"],
             )
-            r[:, margin:-margin, margin:-margin] = m
+            r[:, world_margin:-world_margin, world_margin:-world_margin] = m
             m = r
         return m
 
@@ -221,11 +245,11 @@ class PicroCrafterEngine:
 
     def action2str(self, n):
         if n >= 0 and n < 5:
-            return "XNESW"[n]
+            return "INESW"[n]
         else:
             return "?"
 
-    def nb_view_tiles(self):
+    def nb_state_token_values(self):
         return len(self.tiles)
 
     def min_max_reward(self):
@@ -256,14 +280,18 @@ class PicroCrafterEngine:
 
         nb_hits = self.monster_moves()
 
-        alive_before = self.life_level_in_100th > 99
+        alive_before = self.life_level_in_100th >= 100
+
         self.life_level_in_100th[alive_before] = (
             self.life_level_in_100th[alive_before]
             + self.life_level_gain_100th
             - nb_hits[alive_before] * 100
         ).clamp(max=self.life_level_max * 100 + 99)
-        alive_after = self.life_level_in_100th > 99
+
+        alive_after = self.life_level_in_100th >= 100
+
         self.worlds[torch.logical_not(alive_after)] = self.tile2id["#"]
+
         reward = nb_hits * self.reward_per_hit
 
         for i in range(q.size(0)):
@@ -290,6 +318,7 @@ class PicroCrafterEngine:
         )
 
         reward[torch.logical_not(alive_before)] = 0
+
         return reward, inventory, self.life_level_in_100th // 100
 
     def monster_moves(self):
@@ -361,14 +390,17 @@ class PicroCrafterEngine:
 
         return nb_hits
 
-    def views(self):
+    def state_size(self):
+        return (self.view_height + 1) * self.view_width
+
+    def state(self):
         i_height, i_width = (
-            self.view_height - 2 * self.margin,
-            self.view_width - 2 * self.margin,
+            self.view_height - 2 * self.world_margin,
+            self.view_width - 2 * self.world_margin,
         )
         a = (self.worlds == self.tile2id["@"]).nonzero()
-        y = i_height * ((a[:, 1] - self.margin) // i_height)
-        x = i_width * ((a[:, 2] - self.margin) // i_width)
+        y = i_height * ((a[:, 1] - self.world_margin) // i_height)
+        x = i_width * ((a[:, 2] - self.world_margin) // i_width)
         n = a[:, 0][:, None, None].expand(-1, self.view_height, self.view_width)
         i = (
             torch.arange(self.view_height, device=a.device)[None, :, None]
@@ -397,9 +429,9 @@ class PicroCrafterEngine:
             device=v.device,
         )
 
-        return v
+        return v.flatten(1), self.life_level_in_100th >= 100
 
-    def seq2tilepic(self, t, width):
+    def state2str(self, t, width=None):
         def tile(n):
             n = n.item()
             if n in self.id2tile:
@@ -408,49 +440,17 @@ class PicroCrafterEngine:
                 return "?"
 
         if t.dim() == 2:
-            return [self.seq2tilepic(r, width) for r in t]
+            return [self.state2str(r, width) for r in t]
+
+        if width is None:
+            width = self.view_width
 
         t = t.reshape(-1, width)
 
-        t = ["".join([tile(n) for n in r]) for r in t]
+        t = "\n".join(["".join([tile(n) for n in r]) for r in t])
 
         return t
 
-    def print_worlds(
-        self, src=None, comments=[], width=None, printer=print, ansi_term=False
-    ):
-        if src is None:
-            src = list(self.worlds)
-
-        height = max([x.size(0) if torch.is_tensor(x) else 1 for x in src])
-
-        def tile(n):
-            n = n.item()
-            if n in self.id2tile:
-                return self.id2tile[n]
-            else:
-                return "?"
-
-        for k in range(height):
-
-            def f(x):
-                if torch.is_tensor(x):
-                    if x.dim() == 0:
-                        x = str(x.item())
-                        return " " * len(x) if k < height - 1 else x
-                    else:
-                        s = "".join([tile(n) for n in x[k]])
-                        if ansi_term:
-                            for u, c in [("#", 40), ("$", 31), ("@", 32)] + [
-                                (x, 36) for x in "aAbBcC"
-                            ]:
-                                s = s.replace(u, f"\u001b[{c}m{u}\u001b[0m")
-                        return s
-                else:
-                    return " " * len(x) if k < height - 1 else x
-
-            printer("|".join([f(x) for x in src]))
-
 
 ######################################################################
 
@@ -459,36 +459,35 @@ if __name__ == "__main__":
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    # nb_agents, nb_iter, display = 10000, 100, False
+    # char_conv = lambda x: x
+    char_conv = to_unicode
+
+    # nb_agents, nb_iter, display = 1000, 1000, False
     # ansi_term = False
+
     nb_agents, nb_iter, display = 4, 10000, True
     ansi_term = True
 
+    if ansi_term:
+        char_conv = lambda x: to_ansi(to_unicode(x))
+
     start_time = time.perf_counter()
-    engine = PicroCrafterEngine(
+    environment = PicroCrafterEnvironment(
         world_height=27,
         world_width=27,
         nb_walls=35,
-        # world_height=15,
-        # world_width=15,
-        # nb_walls=0,
         view_height=9,
         view_width=9,
-        margin=4,
+        world_margin=4,
         device=device,
     )
 
-    engine.reset(nb_agents)
+    environment.reset(nb_agents)
 
     print(f"timing {nb_agents/(time.perf_counter() - start_time)} init per s")
 
     start_time = time.perf_counter()
 
-    if ansi_term:
-        coloring = add_ansi_coloring
-    else:
-        coloring = lambda x: x
-
     stop = 0
     for k in range(nb_iter):
         if display:
@@ -499,24 +498,29 @@ if __name__ == "__main__":
                 to_print = ""
                 os.system("clear")
 
-            l = engine.seq2tilepic(engine.worlds.flatten(1), width=engine.world_width)
+            l = environment.state2str(
+                environment.worlds.flatten(1), width=environment.world_width
+            )
 
-            to_print += coloring(fusion_multi_lines(l)) + "\n\n"
+            to_print += char_conv(fusion_multi_lines(l)) + "\n\n"
 
-        views = engine.views()
-        action = torch.randint(engine.nb_actions(), (nb_agents,), device=device)
+        state, alive = environment.state()
+        action = alive * torch.randint(
+            environment.nb_actions(), (nb_agents,), device=device
+        )
 
-        rewards, inventories, life_levels = engine.step(action)
+        rewards, inventories, life_levels = environment.step(action)
 
         if display:
-            l = engine.seq2tilepic(views.flatten(1), engine.view_width)
+            l = environment.state2str(state)
             l = [
-                v + [f"{engine.action2str(a.item())}/{r: 3d}"]
+                v + f"\n{environment.action2str(a.item())}/{r: 3d}"
                 for (v, a, r) in zip(l, action, rewards)
             ]
 
             to_print += (
-                coloring(fusion_multi_lines(l, width_min=engine.world_width)) + "\n"
+                char_conv(fusion_multi_lines(l, width_min=environment.world_width))
+                + "\n"
             )
 
             print(to_print)