3 #########################################################################
4 # This program is free software: you can redistribute it and/or modify #
5 # it under the terms of the version 3 of the GNU General Public License #
6 # as published by the Free Software Foundation. #
8 # This program is distributed in the hope that it will be useful, but #
9 # WITHOUT ANY WARRANTY; without even the implied warranty of #
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU #
11 # General Public License for more details. #
13 # You should have received a copy of the GNU General Public License #
14 # along with this program. If not, see <http://www.gnu.org/licenses/>. #
16 # Written by and Copyright (C) Francois Fleuret #
17 # Contact <francois.fleuret@unige.ch> for comments & bug reports #
18 #########################################################################
20 # This is a tiny rogue-like environment implemented with tensor
21 # operations, that runs in batches efficiently on a GPU. On a RTX4090
22 # it can initialize ~20k environments per second and run ~40k
25 # The environment is a rectangular area with walls "#" scattered
26 # randomly. The agent "@" can perform five actions: move NESW or do
29 # There are monsters "$" moving randomly. The agent gets hit by every
30 # monster present in one of the 4 direct neighborhoods at the end of
31 # the moves, each hit results in a reward of -1.
33 # The agent starts with 5 life points, each hit costs it 1pt, when it
34 # gets to 0 it dies, gets a reward of -10 and the episode is over. At
35 # every step it recovers 1/20th of a life point, with a maximum of
38 # The agent can carry "keys" ("a", "b", "c") that open "vaults" ("A",
39 # "B", "C"). The keys and vaults can only be used in sequence:
40 # initially the agent can move only to free spaces, or to the "a", in
41 # which case the key is removed from the environment and the agent now
42 # carries it, and can move to free spaces or the "A". When it moves to
43 # the "A", it gets a reward, loses the "a", the "A" is removed from
44 # the environment, but the agent can now move to the "b", etc. Rewards
45 # are 1 for "A" and "B" and 10 for "C".
47 ######################################################################
51 from torch.nn.functional import conv2d
53 ######################################################################
58 return [to_ansi(x) for x in s]
60 for u, c in [("$", 31), ("@", 32)] + [(x, 36) for x in "aAbBcC"]:
61 s = s.replace(u, f"\u001b[{c}m{u}\u001b[0m")
68 return [to_unicode(x) for x in s]
70 for u, c in [("#", "█"), ("+", "░"), ("|", "│")]:
76 def fusion_multi_lines(l, width_min=0):
77 l = [x if type(x) is str else str(x) for x in l]
79 l = [x.split("\n") for x in l]
83 return " " * (k // 2) + r + " " * (k - k // 2)
86 w = max(width_min, max([len(r) for r in o]))
87 return [" " * w] * (h - len(o)) + [center(r, w) for r in o]
89 h = max([len(x) for x in l])
90 l = [f(o, h) for o in l]
92 return "\n".join(["|".join([o[k] for o in l]) for k in range(h)])
95 class PicroCrafterEnvironment:
104 device=torch.device("cpu"),
106 assert (world_height - 2 * world_margin) % (view_height - 2 * world_margin) == 0
107 assert (world_width - 2 * world_margin) % (view_width - 2 * world_margin) == 0
111 self.world_height = world_height
112 self.world_width = world_width
113 self.world_margin = world_margin
114 self.view_height = view_height
115 self.view_width = view_width
116 self.nb_walls = nb_walls
117 self.life_level_max = 5
118 self.life_level_gain_100th = 5
119 self.reward_per_hit = -1
120 self.reward_death = -10
122 self.tiles = " +#@$aAbBcC-" + "".join(
123 [str(n) for n in range(self.life_level_max + 1)]
125 self.tile2id = dict([(t, n) for n, t in enumerate(self.tiles)])
126 self.id2tile = dict([(n, t) for n, t in enumerate(self.tiles)])
128 self.next_object = dict(
130 (self.tile2id[s], self.tile2id[t])
142 self.object_reward = dict(
156 self.accessible_object_to_inventory = dict(
158 (self.tile2id[s], self.tile2id[t])
171 def reset(self, nb_agents):
172 self.worlds = self.create_worlds(
179 self.life_level_in_100th = torch.full(
180 (nb_agents,), self.life_level_max * 100 + 99, device=self.device
182 self.accessible_object = torch.full(
183 (nb_agents,), self.tile2id["a"], device=self.device
186 def create_mazes(self, nb, height, width, nb_walls):
187 m = torch.zeros(nb, height, width, dtype=torch.int64, device=self.device)
193 i = torch.arange(height, device=m.device)[None, :, None]
194 j = torch.arange(width, device=m.device)[None, None, :]
196 for _ in range(nb_walls):
197 q = torch.rand(m.size(), device=m.device).flatten(1).sort(-1).indices * (
198 (1 - m) * (i % 2 == 0) * (j % 2 == 0)
200 q = (q == q.max(dim=-1, keepdim=True).values).long().view(m.size())
201 a = q[:, None].expand(-1, 4, -1, -1).clone()
202 a[:, 0, :-1, :] += q[:, 1:, :]
203 a[:, 0, :-2, :] += q[:, 2:, :]
204 a[:, 1, 1:, :] += q[:, :-1, :]
205 a[:, 1, 2:, :] += q[:, :-2, :]
206 a[:, 2, :, :-1] += q[:, :, 1:]
207 a[:, 2, :, :-2] += q[:, :, 2:]
208 a[:, 3, :, 1:] += q[:, :, :-1]
209 a[:, 3, :, 2:] += q[:, :, :-2]
211 torch.arange(a.size(0), device=a.device),
212 torch.randint(4, (a.size(0),), device=a.device),
214 m = (m + q + a).clamp(max=1)
218 def create_worlds(self, nb, height, width, nb_walls, world_margin=2):
219 world_margin -= 1 # The maze adds a wall all around
220 m = self.create_mazes(
221 nb, height - 2 * world_margin, width - 2 * world_margin, nb_walls
224 z = "@aAbBcC$$$$$" # What to add to the maze
225 u = torch.rand(q.size(), device=q.device) * (1 - q)
226 r = u.sort(dim=-1, descending=True).indices[:, : len(z)]
228 q *= self.tile2id["#"]
230 torch.arange(q.size(0), device=q.device)[:, None].expand_as(r), r
231 ] = torch.tensor([self.tile2id[c] for c in z], device=q.device)[None, :]
235 (m.size(0), m.size(1) + world_margin * 2, m.size(2) + world_margin * 2),
238 r[:, world_margin:-world_margin, world_margin:-world_margin] = m
242 def nb_actions(self):
245 def action2str(self, n):
def nb_state_token_values(self):
    """Return the number of distinct token values a state cell can take.

    This is the size of the tile vocabulary `self.tiles` (free space,
    walls, agent, monsters, keys/vaults, separator, and life-level
    digits), i.e. the valid range of ids produced by `state()`.
    """
    return len(self.tiles)
def min_max_reward(self):
    """Return the (minimum, maximum) single-step reward.

    The worst case is being hit by all four direct neighbors or dying
    outright, whichever is lower; the best case is the largest vault
    reward in `self.object_reward`.
    """
    return (
        min(4 * self.reward_per_hit, self.reward_death),
        max(self.object_reward.values()),
    )
260 def step(self, actions):
261 a = (self.worlds == self.tile2id["@"]).nonzero()
262 self.worlds[a[:, 0], a[:, 1], a[:, 2]] = self.tile2id[" "]
263 s = torch.tensor([[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]], device=self.device)
265 b[:, 1:] = b[:, 1:] + s[actions[b[:, 0]]]
267 o = (self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.tile2id[" "]).long()
268 # or it is the next accessible object
270 self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.accessible_object[b[:, 0]]
272 o = (o + q).clamp(max=1)[:, None]
273 b = (1 - o) * a + o * b
274 self.worlds[b[:, 0], b[:, 1], b[:, 2]] = self.tile2id["@"]
277 q = qq.new_zeros((self.worlds.size(0),) + qq.size()[1:])
280 nb_hits = self.monster_moves()
282 alive_before = self.life_level_in_100th >= 100
284 self.life_level_in_100th[alive_before] = (
285 self.life_level_in_100th[alive_before]
286 + self.life_level_gain_100th
287 - nb_hits[alive_before] * 100
288 ).clamp(max=self.life_level_max * 100 + 99)
290 alive_after = self.life_level_in_100th >= 100
292 self.worlds[torch.logical_not(alive_after)] = self.tile2id["#"]
294 reward = nb_hits * self.reward_per_hit
296 for i in range(q.size(0)):
298 reward[i] += self.object_reward[self.accessible_object[i].item()]
299 self.accessible_object[i] = self.next_object[
300 self.accessible_object[i].item()
304 alive_after.long() * reward
305 + alive_before.long() * (1 - alive_after.long()) * self.reward_death
307 inventory = torch.tensor(
309 self.accessible_object_to_inventory[s.item()]
310 for s in self.accessible_object
314 self.life_level_in_100th = (
315 self.life_level_in_100th
316 * (self.accessible_object != self.tile2id["-"]).long()
319 reward[torch.logical_not(alive_before)] = 0
321 return reward, inventory, self.life_level_in_100th // 100
323 def monster_moves(self):
324 # Current positions of the monsters
325 m = (self.worlds == self.tile2id["$"]).long().flatten(1)
327 # Total number of monsters
330 # Create a tensor with one channel per monster
332 (torch.rand(m.size(), device=m.device) * m)
333 .sort(dim=-1, descending=True)
336 o = m.new_zeros((m.size(0), n) + m.size()[1:])
337 i = torch.arange(o.size(0), device=o.device)[:, None].expand(-1, o.size(1))
338 j = torch.arange(o.size(1), device=o.device)[None, :].expand(o.size(0), -1)
342 # Create the tensor of possible motions
343 o = o.view((self.worlds.size(0), n) + self.worlds.flatten(1).size()[1:])
344 move_kernel = torch.tensor(
345 [[[[0.0, 1.0, 0.0], [1.0, 1.0, 1.0], [0.0, 1.0, 0.0]]]], device=o.device
351 o.size(0) * o.size(1), 1, self.worlds.size(-2), self.worlds.size(-1)
359 # Let's do the moves per say
360 i = torch.arange(self.worlds.size(0), device=self.worlds.device)[
364 for n in range(p.size(1)):
365 u = o[:, n].sort(dim=-1, descending=True).indices[:, :1]
366 q = p[:, n] * (self.worlds.flatten(1) == self.tile2id[" "]) + o[:, n]
368 (q * torch.rand(q.size(), device=q.device))
369 .sort(dim=-1, descending=True)
372 self.worlds.flatten(1)[i, u] = self.tile2id[" "]
373 self.worlds.flatten(1)[i, r] = self.tile2id["$"]
378 (self.worlds == self.tile2id["$"]).float()[:, None],
384 * (self.worlds == self.tile2id["@"]).long()
def state_size(self):
    """Return the length of the flattened observation vector.

    The observation is the agent's view grid plus one extra status row
    (life level and inventory), hence (view_height + 1) * view_width.
    """
    return (self.view_height + 1) * self.view_width
396 i_height, i_width = (
397 self.view_height - 2 * self.world_margin,
398 self.view_width - 2 * self.world_margin,
400 a = (self.worlds == self.tile2id["@"]).nonzero()
401 y = i_height * ((a[:, 1] - self.world_margin) // i_height)
402 x = i_width * ((a[:, 2] - self.world_margin) // i_width)
403 n = a[:, 0][:, None, None].expand(-1, self.view_height, self.view_width)
405 torch.arange(self.view_height, device=a.device)[None, :, None]
409 torch.arange(self.view_width, device=a.device)[None, None, :]
412 v = self.worlds.new_full(
413 (self.worlds.size(0), self.view_height + 1, self.view_width),
417 v[a[:, 0], : self.view_height] = self.worlds[n, i, j]
419 v[:, self.view_height] = self.tile2id["-"]
420 v[:, self.view_height, 0] = self.tile2id["0"] + (
421 self.life_level_in_100th // 100
422 ).clamp(min=0, max=self.life_level_max)
423 v[:, self.view_height, 1] = torch.tensor(
425 self.accessible_object_to_inventory[o.item()]
426 for o in self.accessible_object
431 return v.flatten(1), self.life_level_in_100th >= 100
433 def state2str(self, t, width=None):
436 if n in self.id2tile:
437 return self.id2tile[n]
442 return [self.state2str(r, width) for r in t]
445 width = self.view_width
447 t = t.reshape(-1, width)
449 t = "\n".join(["".join([tile(n) for n in r]) for r in t])
454 ######################################################################
456 if __name__ == "__main__":
459 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
461 # char_conv = lambda x: x
462 char_conv = to_unicode
464 # nb_agents, nb_iter, display = 1000, 1000, False
467 nb_agents, nb_iter, display = 4, 10000, True
471 char_conv = lambda x: to_ansi(to_unicode(x))
473 start_time = time.perf_counter()
474 environment = PicroCrafterEnvironment(
484 environment.reset(nb_agents)
486 print(f"timing {nb_agents/(time.perf_counter() - start_time)} init per s")
488 start_time = time.perf_counter()
491 for k in range(nb_iter):
500 l = environment.state2str(
501 environment.worlds.flatten(1), width=environment.world_width
504 to_print += char_conv(fusion_multi_lines(l)) + "\n\n"
506 state, alive = environment.state()
507 action = alive * torch.randint(
508 environment.nb_actions(), (nb_agents,), device=device
511 rewards, inventories, life_levels = environment.step(action)
514 l = environment.state2str(state)
516 v + f"\n{environment.action2str(a.item())}/{r: 3d}"
517 for (v, a, r) in zip(l, action, rewards)
521 char_conv(fusion_multi_lines(l, width_min=environment.world_width))
529 if (life_levels > 0).long().sum() == 0:
534 print(f"timing {(nb_agents*k)/(time.perf_counter() - start_time)} iteration per s")