Update.
[pytorch.git] / picocrafter.py
1 #!/usr/bin/env python
2
3 #########################################################################
4 # This program is free software: you can redistribute it and/or modify  #
5 # it under the terms of the version 3 of the GNU General Public License #
6 # as published by the Free Software Foundation.                         #
7 #                                                                       #
8 # This program is distributed in the hope that it will be useful, but   #
9 # WITHOUT ANY WARRANTY; without even the implied warranty of            #
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      #
11 # General Public License for more details.                              #
12 #                                                                       #
13 # You should have received a copy of the GNU General Public License     #
14 # along with this program. If not, see <http://www.gnu.org/licenses/>.  #
15 #                                                                       #
16 # Written by and Copyright (C) Francois Fleuret                         #
17 # Contact <francois.fleuret@unige.ch> for comments & bug reports        #
18 #########################################################################
19
20 # This is a tiny rogue-like environment implemented with tensor
21 # operations, that runs in batches efficiently on a GPU. On a RTX4090
22 # it can initialize ~20k environments per second and run ~40k
23 # iterations.
24 #
25 # The agent "@" moves in a maze-like grid with random walls "#". There
26 # are five actions: move NESW or do not move.
27 #
28 # There are monsters "$" moving randomly. The agent gets hit by every
29 # monster present in one of the 4 direct neighborhoods at the end of
30 # the moves, each hit results in a rewards of -1.
31 #
32 # The agent starts with 5 life points, each hit costs it 1pt, when it
33 # gets to 0 it dies, gets a reward of -10 and the episode is over. At
34 # every step it recovers 1/20th of a life point, with a maximum of
35 # 5pt.
36 #
37 # The agent can carry "keys" ("a", "b", "c") that open "vaults" ("A",
38 # "B", "C"). The keys and vault can only be used in sequence:
39 # initially the agent can move only to free spaces, or to the "a", in
40 # which case the key is removed from the environment and the agent now
41 # carries it, and can move to free spaces or the "A". When it moves to
42 # the "A", it gets a reward, loses the "a", the "A" is removed from
43 # the environment, but can now move to the "b", etc. Rewards are 1 for
44 # "A" and "B" and 10 for "C".
45
46 ######################################################################
47
48 import torch
49
50 from torch.nn.functional import conv2d
51
52 ######################################################################
53
54
55 def add_ansi_coloring(s):
56     if type(s) is list:
57         return [add_ansi_coloring(x) for x in s]
58
59     for u, c in [("#", 40), ("$", 31), ("@", 32)] + [(x, 36) for x in "aAbBcC"]:
60         s = s.replace(u, f"\u001b[{c}m{u}\u001b[0m")
61
62     return s
63
64
65 def fusion_multi_lines(l, width_min=0):
66     l = [x if type(x) is list else [str(x)] for x in l]
67
68     def f(o, h):
69         w = max(width_min, max([len(r) for r in o]))
70         return [" " * w] * (h - len(o)) + [r + " " * (w - len(r)) for r in o]
71
72     h = max([len(x) for x in l])
73     l = [f(o, h) for o in l]
74
75     return "\n".join(["|".join([o[k] for o in l]) for k in range(h)])
76
77
78 class PicroCrafterEngine:
79     def __init__(
80         self,
81         world_height=27,
82         world_width=27,
83         nb_walls=27,
84         margin=2,
85         view_height=5,
86         view_width=5,
87         device=torch.device("cpu"),
88     ):
89         assert (world_height - 2 * margin) % (view_height - 2 * margin) == 0
90         assert (world_width - 2 * margin) % (view_width - 2 * margin) == 0
91
92         self.device = device
93
94         self.world_height = world_height
95         self.world_width = world_width
96         self.margin = margin
97         self.view_height = view_height
98         self.view_width = view_width
99         self.nb_walls = nb_walls
100         self.life_level_max = 5
101         self.life_level_gain_100th = 5
102         self.reward_per_hit = -1
103         self.reward_death = -10
104
105         self.tiles = " +#@$aAbBcC-" + "".join(
106             [str(n) for n in range(self.life_level_max + 1)]
107         )
108         self.tile2id = dict([(t, n) for n, t in enumerate(self.tiles)])
109         self.id2tile = dict([(n, t) for n, t in enumerate(self.tiles)])
110
111         self.next_object = dict(
112             [
113                 (self.tile2id[s], self.tile2id[t])
114                 for (s, t) in [
115                     ("a", "A"),
116                     ("A", "b"),
117                     ("b", "B"),
118                     ("B", "c"),
119                     ("c", "C"),
120                     ("C", "-"),
121                 ]
122             ]
123         )
124
125         self.object_reward = dict(
126             [
127                 (self.tile2id[t], r)
128                 for (t, r) in [
129                     ("a", 0),
130                     ("A", 1),
131                     ("b", 0),
132                     ("B", 1),
133                     ("c", 0),
134                     ("C", 10),
135                 ]
136             ]
137         )
138
139         self.accessible_object_to_inventory = dict(
140             [
141                 (self.tile2id[s], self.tile2id[t])
142                 for (s, t) in [
143                     ("a", " "),
144                     ("A", "a"),
145                     ("b", " "),
146                     ("B", "b"),
147                     ("c", " "),
148                     ("C", "c"),
149                     ("-", " "),
150                 ]
151             ]
152         )
153
154     def reset(self, nb_agents):
155         self.worlds = self.create_worlds(
156             nb_agents, self.world_height, self.world_width, self.nb_walls, self.margin
157         ).to(self.device)
158         self.life_level_in_100th = torch.full(
159             (nb_agents,), self.life_level_max * 100 + 99, device=self.device
160         )
161         self.accessible_object = torch.full(
162             (nb_agents,), self.tile2id["a"], device=self.device
163         )
164
165     def create_mazes(self, nb, height, width, nb_walls):
166         m = torch.zeros(nb, height, width, dtype=torch.int64, device=self.device)
167         m[:, 0, :] = 1
168         m[:, -1, :] = 1
169         m[:, :, 0] = 1
170         m[:, :, -1] = 1
171
172         i = torch.arange(height, device=m.device)[None, :, None]
173         j = torch.arange(width, device=m.device)[None, None, :]
174
175         for _ in range(nb_walls):
176             q = torch.rand(m.size(), device=m.device).flatten(1).sort(-1).indices * (
177                 (1 - m) * (i % 2 == 0) * (j % 2 == 0)
178             ).flatten(1)
179             q = (q == q.max(dim=-1, keepdim=True).values).long().view(m.size())
180             a = q[:, None].expand(-1, 4, -1, -1).clone()
181             a[:, 0, :-1, :] += q[:, 1:, :]
182             a[:, 0, :-2, :] += q[:, 2:, :]
183             a[:, 1, 1:, :] += q[:, :-1, :]
184             a[:, 1, 2:, :] += q[:, :-2, :]
185             a[:, 2, :, :-1] += q[:, :, 1:]
186             a[:, 2, :, :-2] += q[:, :, 2:]
187             a[:, 3, :, 1:] += q[:, :, :-1]
188             a[:, 3, :, 2:] += q[:, :, :-2]
189             a = a[
190                 torch.arange(a.size(0), device=a.device),
191                 torch.randint(4, (a.size(0),), device=a.device),
192             ]
193             m = (m + q + a).clamp(max=1)
194
195         return m
196
197     def create_worlds(self, nb, height, width, nb_walls, margin=2):
198         margin -= 1  # The maze adds a wall all around
199         m = self.create_mazes(nb, height - 2 * margin, width - 2 * margin, nb_walls)
200         q = m.flatten(1)
201         z = "@aAbBcC$$$$$"  # What to add to the maze
202         u = torch.rand(q.size(), device=q.device) * (1 - q)
203         r = u.sort(dim=-1, descending=True).indices[:, : len(z)]
204
205         q *= self.tile2id["#"]
206         q[
207             torch.arange(q.size(0), device=q.device)[:, None].expand_as(r), r
208         ] = torch.tensor([self.tile2id[c] for c in z], device=q.device)[None, :]
209
210         if margin > 0:
211             r = m.new_full(
212                 (m.size(0), m.size(1) + margin * 2, m.size(2) + margin * 2),
213                 self.tile2id["+"],
214             )
215             r[:, margin:-margin, margin:-margin] = m
216             m = r
217         return m
218
219     def nb_actions(self):
220         return 5
221
222     def action2str(self, n):
223         if n >= 0 and n < 5:
224             return "XNESW"[n]
225         else:
226             return "?"
227
228     def nb_view_tiles(self):
229         return len(self.tiles)
230
231     def min_max_reward(self):
232         return (
233             min(4 * self.reward_per_hit, self.reward_death),
234             max(self.object_reward.values()),
235         )
236
237     def step(self, actions):
238         a = (self.worlds == self.tile2id["@"]).nonzero()
239         self.worlds[a[:, 0], a[:, 1], a[:, 2]] = self.tile2id[" "]
240         s = torch.tensor([[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]], device=self.device)
241         b = a.clone()
242         b[:, 1:] = b[:, 1:] + s[actions[b[:, 0]]]
243         # position is empty
244         o = (self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.tile2id[" "]).long()
245         # or it is the next accessible object
246         q = (
247             self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.accessible_object[b[:, 0]]
248         ).long()
249         o = (o + q).clamp(max=1)[:, None]
250         b = (1 - o) * a + o * b
251         self.worlds[b[:, 0], b[:, 1], b[:, 2]] = self.tile2id["@"]
252
253         qq = q
254         q = qq.new_zeros((self.worlds.size(0),) + qq.size()[1:])
255         q[b[:, 0]] = qq
256
257         nb_hits = self.monster_moves()
258
259         alive_before = self.life_level_in_100th > 99
260         self.life_level_in_100th[alive_before] = (
261             self.life_level_in_100th[alive_before]
262             + self.life_level_gain_100th
263             - nb_hits[alive_before] * 100
264         ).clamp(max=self.life_level_max * 100 + 99)
265         alive_after = self.life_level_in_100th > 99
266         self.worlds[torch.logical_not(alive_after)] = self.tile2id["#"]
267         reward = nb_hits * self.reward_per_hit
268
269         for i in range(q.size(0)):
270             if q[i] == 1:
271                 reward[i] += self.object_reward[self.accessible_object[i].item()]
272                 self.accessible_object[i] = self.next_object[
273                     self.accessible_object[i].item()
274                 ]
275
276         reward = (
277             alive_after.long() * reward
278             + alive_before.long() * (1 - alive_after.long()) * self.reward_death
279         )
280         inventory = torch.tensor(
281             [
282                 self.accessible_object_to_inventory[s.item()]
283                 for s in self.accessible_object
284             ]
285         )
286
287         self.life_level_in_100th = (
288             self.life_level_in_100th
289             * (self.accessible_object != self.tile2id["-"]).long()
290         )
291
292         reward[torch.logical_not(alive_before)] = 0
293         return reward, inventory, self.life_level_in_100th // 100
294
295     def monster_moves(self):
296         # Current positions of the monsters
297         m = (self.worlds == self.tile2id["$"]).long().flatten(1)
298
299         # Total number of monsters
300         n = m.sum(-1).max()
301
302         # Create a tensor with one channel per monster
303         r = (
304             (torch.rand(m.size(), device=m.device) * m)
305             .sort(dim=-1, descending=True)
306             .indices[:, :n]
307         )
308         o = m.new_zeros((m.size(0), n) + m.size()[1:])
309         i = torch.arange(o.size(0), device=o.device)[:, None].expand(-1, o.size(1))
310         j = torch.arange(o.size(1), device=o.device)[None, :].expand(o.size(0), -1)
311         o[i, j, r] = 1
312         o = o * m[:, None]
313
314         # Create the tensor of possible motions
315         o = o.view((self.worlds.size(0), n) + self.worlds.flatten(1).size()[1:])
316         move_kernel = torch.tensor(
317             [[[[0.0, 1.0, 0.0], [1.0, 1.0, 1.0], [0.0, 1.0, 0.0]]]], device=o.device
318         )
319
320         p = (
321             conv2d(
322                 o.view(
323                     o.size(0) * o.size(1), 1, self.worlds.size(-2), self.worlds.size(-1)
324                 ).float(),
325                 move_kernel,
326                 padding=1,
327             ).view(o.size())
328             == 1.0
329         ).long()
330
331         # Let's do the moves per say
332         i = torch.arange(self.worlds.size(0), device=self.worlds.device)[
333             :, None
334         ].expand_as(r)
335
336         for n in range(p.size(1)):
337             u = o[:, n].sort(dim=-1, descending=True).indices[:, :1]
338             q = p[:, n] * (self.worlds.flatten(1) == self.tile2id[" "]) + o[:, n]
339             r = (
340                 (q * torch.rand(q.size(), device=q.device))
341                 .sort(dim=-1, descending=True)
342                 .indices[:, :1]
343             )
344             self.worlds.flatten(1)[i, u] = self.tile2id[" "]
345             self.worlds.flatten(1)[i, r] = self.tile2id["$"]
346
347         nb_hits = (
348             (
349                 conv2d(
350                     (self.worlds == self.tile2id["$"]).float()[:, None],
351                     move_kernel,
352                     padding=1,
353                 )
354                 .long()
355                 .squeeze(1)
356                 * (self.worlds == self.tile2id["@"]).long()
357             )
358             .flatten(1)
359             .sum(-1)
360         )
361
362         return nb_hits
363
364     def views(self):
365         i_height, i_width = (
366             self.view_height - 2 * self.margin,
367             self.view_width - 2 * self.margin,
368         )
369         a = (self.worlds == self.tile2id["@"]).nonzero()
370         y = i_height * ((a[:, 1] - self.margin) // i_height)
371         x = i_width * ((a[:, 2] - self.margin) // i_width)
372         n = a[:, 0][:, None, None].expand(-1, self.view_height, self.view_width)
373         i = (
374             torch.arange(self.view_height, device=a.device)[None, :, None]
375             + y[:, None, None]
376         ).expand_as(n)
377         j = (
378             torch.arange(self.view_width, device=a.device)[None, None, :]
379             + x[:, None, None]
380         ).expand_as(n)
381         v = self.worlds.new_full(
382             (self.worlds.size(0), self.view_height + 1, self.view_width),
383             self.tile2id["#"],
384         )
385
386         v[a[:, 0], : self.view_height] = self.worlds[n, i, j]
387
388         v[:, self.view_height] = self.tile2id["-"]
389         v[:, self.view_height, 0] = self.tile2id["0"] + (
390             self.life_level_in_100th // 100
391         ).clamp(min=0, max=self.life_level_max)
392         v[:, self.view_height, 1] = torch.tensor(
393             [
394                 self.accessible_object_to_inventory[o.item()]
395                 for o in self.accessible_object
396             ],
397             device=v.device,
398         )
399
400         return v
401
402     def seq2tilepic(self, t, width):
403         def tile(n):
404             n = n.item()
405             if n in self.id2tile:
406                 return self.id2tile[n]
407             else:
408                 return "?"
409
410         if t.dim() == 2:
411             return [self.seq2tilepic(r, width) for r in t]
412
413         t = t.reshape(-1, width)
414
415         t = ["".join([tile(n) for n in r]) for r in t]
416
417         return t
418
419     def print_worlds(
420         self, src=None, comments=[], width=None, printer=print, ansi_term=False
421     ):
422         if src is None:
423             src = list(self.worlds)
424
425         height = max([x.size(0) if torch.is_tensor(x) else 1 for x in src])
426
427         def tile(n):
428             n = n.item()
429             if n in self.id2tile:
430                 return self.id2tile[n]
431             else:
432                 return "?"
433
434         for k in range(height):
435
436             def f(x):
437                 if torch.is_tensor(x):
438                     if x.dim() == 0:
439                         x = str(x.item())
440                         return " " * len(x) if k < height - 1 else x
441                     else:
442                         s = "".join([tile(n) for n in x[k]])
443                         if ansi_term:
444                             for u, c in [("#", 40), ("$", 31), ("@", 32)] + [
445                                 (x, 36) for x in "aAbBcC"
446                             ]:
447                                 s = s.replace(u, f"\u001b[{c}m{u}\u001b[0m")
448                         return s
449                 else:
450                     return " " * len(x) if k < height - 1 else x
451
452             printer("|".join([f(x) for x in src]))
453
454
455 ######################################################################
456
457 if __name__ == "__main__":
458     import os, time, sys
459
460     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
461
462     # nb_agents, nb_iter, display = 10000, 100, False
463     # ansi_term = False
464     nb_agents, nb_iter, display = 4, 10000, True
465     ansi_term = True
466
467     start_time = time.perf_counter()
468     engine = PicroCrafterEngine(
469         world_height=27,
470         world_width=27,
471         nb_walls=35,
472         # world_height=15,
473         # world_width=15,
474         # nb_walls=0,
475         view_height=9,
476         view_width=9,
477         margin=4,
478         device=device,
479     )
480
481     engine.reset(nb_agents)
482
483     print(f"timing {nb_agents/(time.perf_counter() - start_time)} init per s")
484
485     start_time = time.perf_counter()
486
487     if ansi_term:
488         coloring = add_ansi_coloring
489     else:
490         coloring = lambda x: x
491
492     stop = 0
493     for k in range(nb_iter):
494         if display:
495             if ansi_term:
496                 to_print = "\u001bc"
497                 # print("\u001b[2J")
498             else:
499                 to_print = ""
500                 os.system("clear")
501
502             l = engine.seq2tilepic(engine.worlds.flatten(1), width=engine.world_width)
503
504             to_print += coloring(fusion_multi_lines(l)) + "\n\n"
505
506         views = engine.views()
507         action = torch.randint(engine.nb_actions(), (nb_agents,), device=device)
508
509         rewards, inventories, life_levels = engine.step(action)
510
511         if display:
512             l = engine.seq2tilepic(views.flatten(1), engine.view_width)
513             l = [
514                 v + [f"{engine.action2str(a.item())}/{r: 3d}"]
515                 for (v, a, r) in zip(l, action, rewards)
516             ]
517
518             to_print += (
519                 coloring(fusion_multi_lines(l, width_min=engine.world_width)) + "\n"
520             )
521
522             print(to_print)
523             sys.stdout.flush()
524             time.sleep(0.25)
525
526         if (life_levels > 0).long().sum() == 0:
527             stop += 1
528             if stop == 10:
529                 break
530
531     print(f"timing {(nb_agents*k)/(time.perf_counter() - start_time)} iteration per s")