Update.
[pytorch.git] / picocrafter.py
index 36088ac..23d93b2 100755 (executable)
@@ -23,8 +23,8 @@
 # iterations.
 #
 # The environment is a rectangular area with walls "#" dispatched
-# randomly. The agent "@" can perform five actions: move NESW or do
-# not move.
+# randomly. The agent "@" can perform five actions: move "NESW" or be
+# immobile "I".
 #
 # There are monsters "$" moving randomly. The agent gets hit by every
 # monster present in one of the 4 direct neighborhoods at the end of
@@ -39,8 +39,9 @@
 # "B", "C"). The keys and vault can only be used in sequence:
 # initially the agent can move only to free spaces, or to the "a", in
 # which case the key is removed from the environment and the agent now
-# carries it, and can move to free spaces or the "A". When it moves to
-# the "A", it gets a reward, loses the "a", the "A" is removed from
+# carries it, it appears in the inventory at the bottom of the frame,
+# and the agent can now move to free spaces or the "A". When it moves
+# to the "A", it gets a reward, loses the "a", the "A" is removed from
 # the environment, but the agent can now move to the "b", etc. Rewards
 # are 1 for "A" and "B" and 10 for "C".
 
@@ -74,7 +75,9 @@ def to_unicode(s):
 
 
 def fusion_multi_lines(l, width_min=0):
-    l = [x if type(x) is list else [str(x)] for x in l]
+    l = [x if type(x) is str else str(x) for x in l]
+
+    l = [x.split("\n") for x in l]
 
     def center(r, w):
         k = w - len(r)
@@ -90,7 +93,7 @@ def fusion_multi_lines(l, width_min=0):
     return "\n".join(["|".join([o[k] for o in l]) for k in range(h)])
 
 
-class PicroCrafterEngine:
+class PicroCrafterEnvironment:
     def __init__(
         self,
         world_height=27,
@@ -242,11 +245,11 @@ class PicroCrafterEngine:
 
     def action2str(self, n):
         if n >= 0 and n < 5:
-            return "XNESW"[n]
+            return "INESW"[n]
         else:
             return "?"
 
-    def nb_view_tiles(self):
+    def nb_state_token_values(self):
         return len(self.tiles)
 
     def min_max_reward(self):
@@ -277,14 +280,18 @@ class PicroCrafterEngine:
 
         nb_hits = self.monster_moves()
 
-        alive_before = self.life_level_in_100th > 99
+        alive_before = self.life_level_in_100th >= 100
+
         self.life_level_in_100th[alive_before] = (
             self.life_level_in_100th[alive_before]
             + self.life_level_gain_100th
             - nb_hits[alive_before] * 100
         ).clamp(max=self.life_level_max * 100 + 99)
-        alive_after = self.life_level_in_100th > 99
+
+        alive_after = self.life_level_in_100th >= 100
+
         self.worlds[torch.logical_not(alive_after)] = self.tile2id["#"]
+
         reward = nb_hits * self.reward_per_hit
 
         for i in range(q.size(0)):
@@ -311,6 +318,7 @@ class PicroCrafterEngine:
         )
 
         reward[torch.logical_not(alive_before)] = 0
+
         return reward, inventory, self.life_level_in_100th // 100
 
     def monster_moves(self):
@@ -382,7 +390,10 @@ class PicroCrafterEngine:
 
         return nb_hits
 
-    def views(self):
+    def state_size(self):
+        return (self.view_height + 1) * self.view_width
+
+    def state(self):
         i_height, i_width = (
             self.view_height - 2 * self.world_margin,
             self.view_width - 2 * self.world_margin,
@@ -418,9 +429,9 @@ class PicroCrafterEngine:
             device=v.device,
         )
 
-        return v
+        return v.flatten(1), self.life_level_in_100th >= 100
 
-    def seq2tiles(self, t, width=None):
+    def state2str(self, t, width=None):
         def tile(n):
             n = n.item()
             if n in self.id2tile:
@@ -429,14 +440,14 @@ class PicroCrafterEngine:
                 return "?"
 
         if t.dim() == 2:
-            return [self.seq2tiles(r, width) for r in t]
+            return [self.state2str(r, width) for r in t]
 
         if width is None:
             width = self.view_width
 
         t = t.reshape(-1, width)
 
-        t = ["".join([tile(n) for n in r]) for r in t]
+        t = "\n".join(["".join([tile(n) for n in r]) for r in t])
 
         return t
 
@@ -461,7 +472,7 @@ if __name__ == "__main__":
         char_conv = lambda x: to_ansi(to_unicode(x))
 
     start_time = time.perf_counter()
-    engine = PicroCrafterEngine(
+    environment = PicroCrafterEnvironment(
         world_height=27,
         world_width=27,
         nb_walls=35,
@@ -471,7 +482,7 @@ if __name__ == "__main__":
         device=device,
     )
 
-    engine.reset(nb_agents)
+    environment.reset(nb_agents)
 
     print(f"timing {nb_agents/(time.perf_counter() - start_time)} init per s")
 
@@ -487,24 +498,29 @@ if __name__ == "__main__":
                 to_print = ""
                 os.system("clear")
 
-            l = engine.seq2tiles(engine.worlds.flatten(1), width=engine.world_width)
+            l = environment.state2str(
+                environment.worlds.flatten(1), width=environment.world_width
+            )
 
             to_print += char_conv(fusion_multi_lines(l)) + "\n\n"
 
-        views = engine.views()
-        action = torch.randint(engine.nb_actions(), (nb_agents,), device=device)
+        state, alive = environment.state()
+        action = alive * torch.randint(
+            environment.nb_actions(), (nb_agents,), device=device
+        )
 
-        rewards, inventories, life_levels = engine.step(action)
+        rewards, inventories, life_levels = environment.step(action)
 
         if display:
-            l = engine.seq2tiles(views.flatten(1))
+            l = environment.state2str(state)
             l = [
-                v + [f"{engine.action2str(a.item())}/{r: 3d}"]
+                v + f"\n{environment.action2str(a.item())}/{r: 3d}"
                 for (v, a, r) in zip(l, action, rewards)
             ]
 
             to_print += (
-                char_conv(fusion_multi_lines(l, width_min=engine.world_width)) + "\n"
+                char_conv(fusion_multi_lines(l, width_min=environment.world_width))
+                + "\n"
             )
 
             print(to_print)