Update.
author François Fleuret <francois@fleuret.org>
Sun, 12 Nov 2023 07:14:52 +0000 (08:14 +0100)
committer François Fleuret <francois@fleuret.org>
Sun, 12 Nov 2023 07:14:52 +0000 (08:14 +0100)
picocrafter.py

index 36088ac..5bd6a48 100755 (executable)
@@ -74,7 +74,9 @@ def to_unicode(s):
 
 
 def fusion_multi_lines(l, width_min=0):
-    l = [x if type(x) is list else [str(x)] for x in l]
+    l = [x if type(x) is str else str(x) for x in l]
+
+    l = [x.split("\n") for x in l]
 
     def center(r, w):
         k = w - len(r)
@@ -90,7 +92,7 @@ def fusion_multi_lines(l, width_min=0):
     return "\n".join(["|".join([o[k] for o in l]) for k in range(h)])
 
 
-class PicroCrafterEngine:
+class PicroCrafterEnvironment:
     def __init__(
         self,
         world_height=27,
@@ -246,7 +248,7 @@ class PicroCrafterEngine:
         else:
             return "?"
 
-    def nb_view_tiles(self):
+    def nb_state_token_values(self):
         return len(self.tiles)
 
     def min_max_reward(self):
@@ -277,14 +279,18 @@ class PicroCrafterEngine:
 
         nb_hits = self.monster_moves()
 
-        alive_before = self.life_level_in_100th > 99
+        alive_before = self.life_level_in_100th >= 100
+
         self.life_level_in_100th[alive_before] = (
             self.life_level_in_100th[alive_before]
             + self.life_level_gain_100th
             - nb_hits[alive_before] * 100
         ).clamp(max=self.life_level_max * 100 + 99)
-        alive_after = self.life_level_in_100th > 99
+
+        alive_after = self.life_level_in_100th >= 100
+
         self.worlds[torch.logical_not(alive_after)] = self.tile2id["#"]
+
         reward = nb_hits * self.reward_per_hit
 
         for i in range(q.size(0)):
@@ -311,6 +317,7 @@ class PicroCrafterEngine:
         )
 
         reward[torch.logical_not(alive_before)] = 0
+
         return reward, inventory, self.life_level_in_100th // 100
 
     def monster_moves(self):
@@ -382,7 +389,10 @@ class PicroCrafterEngine:
 
         return nb_hits
 
-    def views(self):
+    def state_size(self):
+        return (self.view_height + 1) * self.view_width
+
+    def state(self):
         i_height, i_width = (
             self.view_height - 2 * self.world_margin,
             self.view_width - 2 * self.world_margin,
@@ -418,9 +428,9 @@ class PicroCrafterEngine:
             device=v.device,
         )
 
-        return v
+        return v.flatten(1), self.life_level_in_100th >= 100
 
-    def seq2tiles(self, t, width=None):
+    def state2str(self, t, width=None):
         def tile(n):
             n = n.item()
             if n in self.id2tile:
@@ -429,14 +439,14 @@ class PicroCrafterEngine:
                 return "?"
 
         if t.dim() == 2:
-            return [self.seq2tiles(r, width) for r in t]
+            return [self.state2str(r, width) for r in t]
 
         if width is None:
             width = self.view_width
 
         t = t.reshape(-1, width)
 
-        t = ["".join([tile(n) for n in r]) for r in t]
+        t = "\n".join(["".join([tile(n) for n in r]) for r in t])
 
         return t
 
@@ -461,7 +471,7 @@ if __name__ == "__main__":
         char_conv = lambda x: to_ansi(to_unicode(x))
 
     start_time = time.perf_counter()
-    engine = PicroCrafterEngine(
+    environment = PicroCrafterEnvironment(
         world_height=27,
         world_width=27,
         nb_walls=35,
@@ -471,7 +481,7 @@ if __name__ == "__main__":
         device=device,
     )
 
-    engine.reset(nb_agents)
+    environment.reset(nb_agents)
 
     print(f"timing {nb_agents/(time.perf_counter() - start_time)} init per s")
 
@@ -487,24 +497,29 @@ if __name__ == "__main__":
                 to_print = ""
                 os.system("clear")
 
-            l = engine.seq2tiles(engine.worlds.flatten(1), width=engine.world_width)
+            l = environment.state2str(
+                environment.worlds.flatten(1), width=environment.world_width
+            )
 
             to_print += char_conv(fusion_multi_lines(l)) + "\n\n"
 
-        views = engine.views()
-        action = torch.randint(engine.nb_actions(), (nb_agents,), device=device)
+        state, alive = environment.state()
+        action = alive * torch.randint(
+            environment.nb_actions(), (nb_agents,), device=device
+        )
 
-        rewards, inventories, life_levels = engine.step(action)
+        rewards, inventories, life_levels = environment.step(action)
 
         if display:
-            l = engine.seq2tiles(views.flatten(1))
+            l = environment.state2str(state)
             l = [
-                v + [f"{engine.action2str(a.item())}/{r: 3d}"]
+                v + f"\n{environment.action2str(a.item())}/{r: 3d}"
                 for (v, a, r) in zip(l, action, rewards)
             ]
 
             to_print += (
-                char_conv(fusion_multi_lines(l, width_min=engine.world_width)) + "\n"
+                char_conv(fusion_multi_lines(l, width_min=environment.world_width))
+                + "\n"
             )
 
             print(to_print)