Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 31 Oct 2023 16:55:45 +0000 (17:55 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 31 Oct 2023 16:55:45 +0000 (17:55 +0100)
picocrafter.py

index 31ba1e4..7810b67 100755 (executable)
@@ -79,7 +79,7 @@ class PicroCrafterEngine:
         self.reward_per_hit = -1
         self.reward_death = -10
 
-        self.tokens = " +#@$aAbBcC"
+        self.tokens = " +#@$aAbBcC."
         self.token2id = dict([(t, n) for n, t in enumerate(self.tokens)])
         self.id2token = dict([(n, t) for n, t in enumerate(self.tokens)])
 
@@ -92,7 +92,7 @@ class PicroCrafterEngine:
                     ("b", "B"),
                     ("B", "c"),
                     ("c", "C"),
-                    ("C", " "),
+                    ("C", "."),
                 ]
             ]
         )
@@ -111,7 +111,7 @@ class PicroCrafterEngine:
             ]
         )
 
-        self.acessible_object_to_inventory = dict(
+        self.accessible_object_to_inventory = dict(
             [
                 (self.token2id[s], self.token2id[t])
                 for (s, t) in [
@@ -120,7 +120,8 @@ class PicroCrafterEngine:
                     ("b", " "),
                     ("B", "b"),
                     ("c", " "),
-                    ("C", " "),
+                    ("C", "c"),
+                    (".", " "),
                 ]
             ]
         )
@@ -208,7 +209,6 @@ class PicroCrafterEngine:
         s = torch.tensor([[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]], device=self.device)
         b = a.clone()
         b[:, 1:] = b[:, 1:] + s[actions[b[:, 0]]]
-
         # position is empty
         o = (self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.token2id[" "]).long()
         # or it is the next accessible object
@@ -219,6 +219,10 @@ class PicroCrafterEngine:
         b = (1 - o) * a + o * b
         self.worlds[b[:, 0], b[:, 1], b[:, 2]] = self.token2id["@"]
 
+        qq = q
+        q = qq.new_zeros((self.worlds.size(0),) + qq.size()[1:])
+        q[b[:, 0]] = qq
+
         nb_hits = self.monster_moves()
 
         alive_before = self.life_level_in_100th > 0
@@ -243,14 +247,14 @@ class PicroCrafterEngine:
         )
         inventory = torch.tensor(
             [
-                self.acessible_object_to_inventory[s.item()]
+                self.accessible_object_to_inventory[s.item()]
                 for s in self.accessible_object
             ]
         )
 
         self.life_level_in_100th = (
             self.life_level_in_100th
-            * (self.accessible_object != self.token2id[" "]).long()
+            * (self.accessible_object != self.token2id["."]).long()
         )
 
         reward[torch.logical_not(alive_before)] = 0
@@ -392,16 +396,19 @@ if __name__ == "__main__":
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    ansi_term = False
-    # nb_agents, nb_iter, display = 1000, 100, False
+    ansi_term = False
+    # nb_agents, nb_iter, display = 1000, 1000, False
     nb_agents, nb_iter, display = 3, 10000, True
-    ansi_term = True
+    ansi_term = True
 
     start_time = time.perf_counter()
     engine = PicroCrafterEngine(
         world_height=27,
         world_width=27,
         nb_walls=35,
+        # world_height=15,
+        # world_width=15,
+        # nb_walls=0,
         view_height=9,
         view_width=9,
         margin=4,
@@ -414,6 +421,7 @@ if __name__ == "__main__":
 
     start_time = time.perf_counter()
 
+    stop = 0
     for k in range(nb_iter):
         action = torch.randint(engine.nb_actions(), (nb_agents,), device=device)
         rewards, inventories, life_levels = engine.step(
@@ -438,7 +446,9 @@ if __name__ == "__main__":
             time.sleep(0.25)
 
         if (life_levels > 0).long().sum() == 0:
-            break
+            stop += 1
+            if stop == 2:
+                break
 
     print(
         f"timing {(nb_agents*nb_iter)/(time.perf_counter() - start_time)} iteration per s"