Update.
[pytorch.git] / picocrafter.py
index 33a00c1..7810b67 100755 (executable)
 # 5pt.
 #
 # The agent can carry "keys" ("a", "b", "c") that open "vaults" ("A",
-# "B", "C"). They keys can only be used in sequence: initially the
-# agent can move only to free spaces, or to the "a", in which case it
-# now carries it, and can move to free spaces or the "A". When it
-# moves to the "A", it gets a reward and loses the "a", but can now
-# move to the "b", etc. Rewards are 1 for "A" and "B" and 10 for "C".
+# "B", "C"). The keys and vault can only be used in sequence:
+# initially the agent can move only to free spaces, or to the "a", in
+# which case the key is removed from the environment and the agent now
+# carries it, and can move to free spaces or the "A". When it moves to
+# the "A", it gets a reward, loses the "a", the "A" is removed from
+# the environment, but can now move to the "b", etc. Rewards are 1 for
+# "A" and "B" and 10 for "C".
 
 ######################################################################
 
@@ -77,7 +79,7 @@ class PicroCrafterEngine:
         self.reward_per_hit = -1
         self.reward_death = -10
 
-        self.tokens = " +#@$aAbBcC"
+        self.tokens = " +#@$aAbBcC."
         self.token2id = dict([(t, n) for n, t in enumerate(self.tokens)])
         self.id2token = dict([(n, t) for n, t in enumerate(self.tokens)])
 
@@ -90,6 +92,7 @@ class PicroCrafterEngine:
                     ("b", "B"),
                     ("B", "c"),
                     ("c", "C"),
+                    ("C", "."),
                 ]
             ]
         )
@@ -108,7 +111,7 @@ class PicroCrafterEngine:
             ]
         )
 
-        self.acessible_object_to_inventory = dict(
+        self.accessible_object_to_inventory = dict(
             [
                 (self.token2id[s], self.token2id[t])
                 for (s, t) in [
@@ -117,7 +120,8 @@ class PicroCrafterEngine:
                     ("b", " "),
                     ("B", "b"),
                     ("c", " "),
-                    ("C", " "),
+                    ("C", "c"),
+                    (".", " "),
                 ]
             ]
         )
@@ -205,7 +209,6 @@ class PicroCrafterEngine:
         s = torch.tensor([[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]], device=self.device)
         b = a.clone()
         b[:, 1:] = b[:, 1:] + s[actions[b[:, 0]]]
-
         # position is empty
         o = (self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.token2id[" "]).long()
         # or it is the next accessible object
@@ -216,6 +219,10 @@ class PicroCrafterEngine:
         b = (1 - o) * a + o * b
         self.worlds[b[:, 0], b[:, 1], b[:, 2]] = self.token2id["@"]
 
+        qq = q
+        q = qq.new_zeros((self.worlds.size(0),) + qq.size()[1:])
+        q[b[:, 0]] = qq
+
         nb_hits = self.monster_moves()
 
         alive_before = self.life_level_in_100th > 0
@@ -240,11 +247,16 @@ class PicroCrafterEngine:
         )
         inventory = torch.tensor(
             [
-                self.acessible_object_to_inventory[s.item()]
+                self.accessible_object_to_inventory[s.item()]
                 for s in self.accessible_object
             ]
         )
 
+        self.life_level_in_100th = (
+            self.life_level_in_100th
+            * (self.accessible_object != self.token2id["."]).long()
+        )
+
         reward[torch.logical_not(alive_before)] = 0
         return reward, inventory, self.life_level_in_100th // 100
 
@@ -384,8 +396,8 @@ if __name__ == "__main__":
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    ansi_term = False
-    # nb_agents, nb_iter, display = 1000, 100, False
+    ansi_term = False
+    # nb_agents, nb_iter, display = 1000, 1000, False
     nb_agents, nb_iter, display = 3, 10000, True
     ansi_term = True
 
@@ -394,6 +406,9 @@ if __name__ == "__main__":
         world_height=27,
         world_width=27,
         nb_walls=35,
+        # world_height=15,
+        # world_width=15,
+        # nb_walls=0,
         view_height=9,
         view_width=9,
         margin=4,
@@ -406,6 +421,7 @@ if __name__ == "__main__":
 
     start_time = time.perf_counter()
 
+    stop = 0
     for k in range(nb_iter):
         action = torch.randint(engine.nb_actions(), (nb_agents,), device=device)
         rewards, inventories, life_levels = engine.step(
@@ -427,10 +443,12 @@ if __name__ == "__main__":
                 width=engine.world_width,
                 ansi_term=ansi_term,
             )
-            time.sleep(0.5)
+            time.sleep(0.25)
 
         if (life_levels > 0).long().sum() == 0:
-            break
+            stop += 1
+            if stop == 2:
+                break
 
     print(
         f"timing {(nb_agents*nb_iter)/(time.perf_counter() - start_time)} iteration per s"