From d507a6f064b68233b4ec14f58ca65e0a2002ac21 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 6 Nov 2023 08:06:33 +0100 Subject: [PATCH] Update. --- picocrafter.py | 49 +++++++++++++++++++++++++++++-------------------- 1 file changed, 29 insertions(+), 20 deletions(-) diff --git a/picocrafter.py b/picocrafter.py index e303554..36088ac 100755 --- a/picocrafter.py +++ b/picocrafter.py @@ -96,19 +96,19 @@ class PicroCrafterEngine: world_height=27, world_width=27, nb_walls=27, - margin=2, + world_margin=2, view_height=5, view_width=5, device=torch.device("cpu"), ): - assert (world_height - 2 * margin) % (view_height - 2 * margin) == 0 - assert (world_width - 2 * margin) % (view_width - 2 * margin) == 0 + assert (world_height - 2 * world_margin) % (view_height - 2 * world_margin) == 0 + assert (world_width - 2 * world_margin) % (view_width - 2 * world_margin) == 0 self.device = device self.world_height = world_height self.world_width = world_width - self.margin = margin + self.world_margin = world_margin self.view_height = view_height self.view_width = view_width self.nb_walls = nb_walls @@ -168,7 +168,11 @@ class PicroCrafterEngine: def reset(self, nb_agents): self.worlds = self.create_worlds( - nb_agents, self.world_height, self.world_width, self.nb_walls, self.margin + nb_agents, + self.world_height, + self.world_width, + self.nb_walls, + self.world_margin, ).to(self.device) self.life_level_in_100th = torch.full( (nb_agents,), self.life_level_max * 100 + 99, device=self.device @@ -209,9 +213,11 @@ class PicroCrafterEngine: return m - def create_worlds(self, nb, height, width, nb_walls, margin=2): - margin -= 1 # The maze adds a wall all around - m = self.create_mazes(nb, height - 2 * margin, width - 2 * margin, nb_walls) + def create_worlds(self, nb, height, width, nb_walls, world_margin=2): + world_margin -= 1 # The maze adds a wall all around + m = self.create_mazes( + nb, height - 2 * world_margin, width - 2 * world_margin, nb_walls + ) q = m.flatten(1) z = "@aAbBcC$$$$$" # What to add to the maze u = torch.rand(q.size(), device=q.device) * (1 - q) @@ -222,12 +228,12 @@ class PicroCrafterEngine: torch.arange(q.size(0), device=q.device)[:, None].expand_as(r), r ] = torch.tensor([self.tile2id[c] for c in z], device=q.device)[None, :] - if margin > 0: + if world_margin > 0: r = m.new_full( - (m.size(0), m.size(1) + margin * 2, m.size(2) + margin * 2), + (m.size(0), m.size(1) + world_margin * 2, m.size(2) + world_margin * 2), self.tile2id["+"], ) - r[:, margin:-margin, margin:-margin] = m + r[:, world_margin:-world_margin, world_margin:-world_margin] = m m = r return m @@ -378,12 +384,12 @@ class PicroCrafterEngine: def views(self): i_height, i_width = ( - self.view_height - 2 * self.margin, - self.view_width - 2 * self.margin, + self.view_height - 2 * self.world_margin, + self.view_width - 2 * self.world_margin, ) a = (self.worlds == self.tile2id["@"]).nonzero() - y = i_height * ((a[:, 1] - self.margin) // i_height) - x = i_width * ((a[:, 2] - self.margin) // i_width) + y = i_height * ((a[:, 1] - self.world_margin) // i_height) + x = i_width * ((a[:, 2] - self.world_margin) // i_width) n = a[:, 0][:, None, None].expand(-1, self.view_height, self.view_width) i = ( torch.arange(self.view_height, device=a.device)[None, :, None] @@ -414,7 +420,7 @@ class PicroCrafterEngine: return v - def seq2tilepic(self, t, width): + def seq2tiles(self, t, width=None): def tile(n): n = n.item() if n in self.id2tile: @@ -423,7 +429,10 @@ class PicroCrafterEngine: return "?" if t.dim() == 2: - return [self.seq2tilepic(r, width) for r in t] + return [self.seq2tiles(r, width) for r in t] + + if width is None: + width = self.view_width t = t.reshape(-1, width) @@ -458,7 +467,7 @@ if __name__ == "__main__": nb_walls=35, view_height=9, view_width=9, - margin=4, + world_margin=4, device=device, ) @@ -478,7 +487,7 @@ if __name__ == "__main__": to_print = "" os.system("clear") - l = engine.seq2tilepic(engine.worlds.flatten(1), width=engine.world_width) + l = engine.seq2tiles(engine.worlds.flatten(1), width=engine.world_width) to_print += char_conv(fusion_multi_lines(l)) + "\n\n" @@ -488,7 +497,7 @@ if __name__ == "__main__": rewards, inventories, life_levels = engine.step(action) if display: - l = engine.seq2tilepic(views.flatten(1), engine.view_width) + l = engine.seq2tiles(views.flatten(1)) l = [ v + [f"{engine.action2str(a.item())}/{r: 3d}"] for (v, a, r) in zip(l, action, rewards) -- 2.39.5