From 2bce3f35e7f9ecbc1b5cf1d27ce313e270aa3bb2 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 31 Oct 2023 17:55:45 +0100 Subject: [PATCH] Update. --- picocrafter.py | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/picocrafter.py b/picocrafter.py index 31ba1e4..7810b67 100755 --- a/picocrafter.py +++ b/picocrafter.py @@ -79,7 +79,7 @@ class PicroCrafterEngine: self.reward_per_hit = -1 self.reward_death = -10 - self.tokens = " +#@$aAbBcC" + self.tokens = " +#@$aAbBcC." self.token2id = dict([(t, n) for n, t in enumerate(self.tokens)]) self.id2token = dict([(n, t) for n, t in enumerate(self.tokens)]) @@ -92,7 +92,7 @@ class PicroCrafterEngine: ("b", "B"), ("B", "c"), ("c", "C"), - ("C", " "), + ("C", "."), ] ] ) @@ -111,7 +111,7 @@ class PicroCrafterEngine: ] ) - self.acessible_object_to_inventory = dict( + self.accessible_object_to_inventory = dict( [ (self.token2id[s], self.token2id[t]) for (s, t) in [ @@ -120,7 +120,8 @@ class PicroCrafterEngine: ("b", " "), ("B", "b"), ("c", " "), - ("C", " "), + ("C", "c"), + (".", " "), ] ] ) @@ -208,7 +209,6 @@ class PicroCrafterEngine: s = torch.tensor([[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]], device=self.device) b = a.clone() b[:, 1:] = b[:, 1:] + s[actions[b[:, 0]]] - # position is empty o = (self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.token2id[" "]).long() # or it is the next accessible object @@ -219,6 +219,10 @@ class PicroCrafterEngine: b = (1 - o) * a + o * b self.worlds[b[:, 0], b[:, 1], b[:, 2]] = self.token2id["@"] + qq = q + q = qq.new_zeros((self.worlds.size(0),) + qq.size()[1:]) + q[b[:, 0]] = qq + nb_hits = self.monster_moves() alive_before = self.life_level_in_100th > 0 @@ -243,14 +247,14 @@ class PicroCrafterEngine: ) inventory = torch.tensor( [ - self.acessible_object_to_inventory[s.item()] + self.accessible_object_to_inventory[s.item()] for s in self.accessible_object ] ) self.life_level_in_100th = ( self.life_level_in_100th - * (self.accessible_object != self.token2id[" "]).long() + * (self.accessible_object != self.token2id["."]).long() ) reward[torch.logical_not(alive_before)] = 0 @@ -392,16 +396,19 @@ if __name__ == "__main__": device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - ansi_term = False - # nb_agents, nb_iter, display = 1000, 100, False + # ansi_term = False + # nb_agents, nb_iter, display = 1000, 1000, False nb_agents, nb_iter, display = 3, 10000, True - # ansi_term = True + ansi_term = True start_time = time.perf_counter() engine = PicroCrafterEngine( world_height=27, world_width=27, nb_walls=35, + # world_height=15, + # world_width=15, + # nb_walls=0, view_height=9, view_width=9, margin=4, @@ -414,6 +421,7 @@ if __name__ == "__main__": start_time = time.perf_counter() + stop = 0 for k in range(nb_iter): action = torch.randint(engine.nb_actions(), (nb_agents,), device=device) rewards, inventories, life_levels = engine.step( @@ -438,7 +446,9 @@ if __name__ == "__main__": time.sleep(0.25) if (life_levels > 0).long().sum() == 0: - break + stop += 1 + if stop == 2: + break print( f"timing {(nb_agents*nb_iter)/(time.perf_counter() - start_time)} iteration per s" -- 2.39.5