Update.
[pytorch.git] / picocrafter.py
1 #!/usr/bin/env python
2
3 #########################################################################
4 # This program is free software: you can redistribute it and/or modify  #
5 # it under the terms of the version 3 of the GNU General Public License #
6 # as published by the Free Software Foundation.                         #
7 #                                                                       #
8 # This program is distributed in the hope that it will be useful, but   #
9 # WITHOUT ANY WARRANTY; without even the implied warranty of            #
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      #
11 # General Public License for more details.                              #
12 #                                                                       #
13 # You should have received a copy of the GNU General Public License     #
14 # along with this program. If not, see <http://www.gnu.org/licenses/>.  #
15 #                                                                       #
16 # Written by and Copyright (C) Francois Fleuret                         #
17 # Contact <francois.fleuret@unige.ch> for comments & bug reports        #
18 #########################################################################
19
20 # This is a tiny rogue-like environment implemented with tensor
21 # operations, that runs in batches efficiently on a GPU. On a RTX4090
22 # it can initialize ~20k environments per second and run ~40k
23 # iterations.
24 #
25 # The environment is a rectangular area with walls "#" dispatched
26 # randomly. The agent "@" can perform five actions: move NESW or do
27 # not move.
28 #
29 # There are monsters "$" moving randomly. The agent gets hit by every
30 # monster present in one of the 4 direct neighborhoods at the end of
31 # the moves, each hit results in a rewards of -1.
32 #
33 # The agent starts with 5 life points, each hit costs it 1pt, when it
34 # gets to 0 it dies, gets a reward of -10 and the episode is over. At
35 # every step it recovers 1/20th of a life point, with a maximum of
36 # 5pt.
37 #
38 # The agent can carry "keys" ("a", "b", "c") that open "vaults" ("A",
39 # "B", "C"). The keys and vault can only be used in sequence:
40 # initially the agent can move only to free spaces, or to the "a", in
41 # which case the key is removed from the environment and the agent now
42 # carries it, and can move to free spaces or the "A". When it moves to
43 # the "A", it gets a reward, loses the "a", the "A" is removed from
44 # the environment, but the agent can now move to the "b", etc. Rewards
45 # are 1 for "A" and "B" and 10 for "C".
46
47 ######################################################################
48
49 import torch
50
51 from torch.nn.functional import conv2d
52
53 ######################################################################
54
55
56 def to_ansi(s):
57     if type(s) is list:
58         return [to_ansi(x) for x in s]
59
60     for u, c in [("$", 31), ("@", 32)] + [(x, 36) for x in "aAbBcC"]:
61         s = s.replace(u, f"\u001b[{c}m{u}\u001b[0m")
62
63     return s
64
65
66 def to_unicode(s):
67     if type(s) is list:
68         return [to_unicode(x) for x in s]
69
70     for u, c in [("#", "█"), ("+", "░"), ("|", "│")]:
71         s = s.replace(u, c)
72
73     return s
74
75
76 def fusion_multi_lines(l, width_min=0):
77     l = [x if type(x) is str else str(x) for x in l]
78
79     l = [x.split("\n") for x in l]
80
81     def center(r, w):
82         k = w - len(r)
83         return " " * (k // 2) + r + " " * (k - k // 2)
84
85     def f(o, h):
86         w = max(width_min, max([len(r) for r in o]))
87         return [" " * w] * (h - len(o)) + [center(r, w) for r in o]
88
89     h = max([len(x) for x in l])
90     l = [f(o, h) for o in l]
91
92     return "\n".join(["|".join([o[k] for o in l]) for k in range(h)])
93
94
95 class PicroCrafterEnvironment:
96     def __init__(
97         self,
98         world_height=27,
99         world_width=27,
100         nb_walls=27,
101         world_margin=2,
102         view_height=5,
103         view_width=5,
104         device=torch.device("cpu"),
105     ):
106         assert (world_height - 2 * world_margin) % (view_height - 2 * world_margin) == 0
107         assert (world_width - 2 * world_margin) % (view_width - 2 * world_margin) == 0
108
109         self.device = device
110
111         self.world_height = world_height
112         self.world_width = world_width
113         self.world_margin = world_margin
114         self.view_height = view_height
115         self.view_width = view_width
116         self.nb_walls = nb_walls
117         self.life_level_max = 5
118         self.life_level_gain_100th = 5
119         self.reward_per_hit = -1
120         self.reward_death = -10
121
122         self.tiles = " +#@$aAbBcC-" + "".join(
123             [str(n) for n in range(self.life_level_max + 1)]
124         )
125         self.tile2id = dict([(t, n) for n, t in enumerate(self.tiles)])
126         self.id2tile = dict([(n, t) for n, t in enumerate(self.tiles)])
127
128         self.next_object = dict(
129             [
130                 (self.tile2id[s], self.tile2id[t])
131                 for (s, t) in [
132                     ("a", "A"),
133                     ("A", "b"),
134                     ("b", "B"),
135                     ("B", "c"),
136                     ("c", "C"),
137                     ("C", "-"),
138                 ]
139             ]
140         )
141
142         self.object_reward = dict(
143             [
144                 (self.tile2id[t], r)
145                 for (t, r) in [
146                     ("a", 0),
147                     ("A", 1),
148                     ("b", 0),
149                     ("B", 1),
150                     ("c", 0),
151                     ("C", 10),
152                 ]
153             ]
154         )
155
156         self.accessible_object_to_inventory = dict(
157             [
158                 (self.tile2id[s], self.tile2id[t])
159                 for (s, t) in [
160                     ("a", " "),
161                     ("A", "a"),
162                     ("b", " "),
163                     ("B", "b"),
164                     ("c", " "),
165                     ("C", "c"),
166                     ("-", " "),
167                 ]
168             ]
169         )
170
171     def reset(self, nb_agents):
172         self.worlds = self.create_worlds(
173             nb_agents,
174             self.world_height,
175             self.world_width,
176             self.nb_walls,
177             self.world_margin,
178         ).to(self.device)
179         self.life_level_in_100th = torch.full(
180             (nb_agents,), self.life_level_max * 100 + 99, device=self.device
181         )
182         self.accessible_object = torch.full(
183             (nb_agents,), self.tile2id["a"], device=self.device
184         )
185
186     def create_mazes(self, nb, height, width, nb_walls):
187         m = torch.zeros(nb, height, width, dtype=torch.int64, device=self.device)
188         m[:, 0, :] = 1
189         m[:, -1, :] = 1
190         m[:, :, 0] = 1
191         m[:, :, -1] = 1
192
193         i = torch.arange(height, device=m.device)[None, :, None]
194         j = torch.arange(width, device=m.device)[None, None, :]
195
196         for _ in range(nb_walls):
197             q = torch.rand(m.size(), device=m.device).flatten(1).sort(-1).indices * (
198                 (1 - m) * (i % 2 == 0) * (j % 2 == 0)
199             ).flatten(1)
200             q = (q == q.max(dim=-1, keepdim=True).values).long().view(m.size())
201             a = q[:, None].expand(-1, 4, -1, -1).clone()
202             a[:, 0, :-1, :] += q[:, 1:, :]
203             a[:, 0, :-2, :] += q[:, 2:, :]
204             a[:, 1, 1:, :] += q[:, :-1, :]
205             a[:, 1, 2:, :] += q[:, :-2, :]
206             a[:, 2, :, :-1] += q[:, :, 1:]
207             a[:, 2, :, :-2] += q[:, :, 2:]
208             a[:, 3, :, 1:] += q[:, :, :-1]
209             a[:, 3, :, 2:] += q[:, :, :-2]
210             a = a[
211                 torch.arange(a.size(0), device=a.device),
212                 torch.randint(4, (a.size(0),), device=a.device),
213             ]
214             m = (m + q + a).clamp(max=1)
215
216         return m
217
218     def create_worlds(self, nb, height, width, nb_walls, world_margin=2):
219         world_margin -= 1  # The maze adds a wall all around
220         m = self.create_mazes(
221             nb, height - 2 * world_margin, width - 2 * world_margin, nb_walls
222         )
223         q = m.flatten(1)
224         z = "@aAbBcC$$$$$"  # What to add to the maze
225         u = torch.rand(q.size(), device=q.device) * (1 - q)
226         r = u.sort(dim=-1, descending=True).indices[:, : len(z)]
227
228         q *= self.tile2id["#"]
229         q[
230             torch.arange(q.size(0), device=q.device)[:, None].expand_as(r), r
231         ] = torch.tensor([self.tile2id[c] for c in z], device=q.device)[None, :]
232
233         if world_margin > 0:
234             r = m.new_full(
235                 (m.size(0), m.size(1) + world_margin * 2, m.size(2) + world_margin * 2),
236                 self.tile2id["+"],
237             )
238             r[:, world_margin:-world_margin, world_margin:-world_margin] = m
239             m = r
240         return m
241
242     def nb_actions(self):
243         return 5
244
245     def action2str(self, n):
246         if n >= 0 and n < 5:
247             return "XNESW"[n]
248         else:
249             return "?"
250
251     def nb_state_token_values(self):
252         return len(self.tiles)
253
254     def min_max_reward(self):
255         return (
256             min(4 * self.reward_per_hit, self.reward_death),
257             max(self.object_reward.values()),
258         )
259
260     def step(self, actions):
261         a = (self.worlds == self.tile2id["@"]).nonzero()
262         self.worlds[a[:, 0], a[:, 1], a[:, 2]] = self.tile2id[" "]
263         s = torch.tensor([[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]], device=self.device)
264         b = a.clone()
265         b[:, 1:] = b[:, 1:] + s[actions[b[:, 0]]]
266         # position is empty
267         o = (self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.tile2id[" "]).long()
268         # or it is the next accessible object
269         q = (
270             self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.accessible_object[b[:, 0]]
271         ).long()
272         o = (o + q).clamp(max=1)[:, None]
273         b = (1 - o) * a + o * b
274         self.worlds[b[:, 0], b[:, 1], b[:, 2]] = self.tile2id["@"]
275
276         qq = q
277         q = qq.new_zeros((self.worlds.size(0),) + qq.size()[1:])
278         q[b[:, 0]] = qq
279
280         nb_hits = self.monster_moves()
281
282         alive_before = self.life_level_in_100th >= 100
283
284         self.life_level_in_100th[alive_before] = (
285             self.life_level_in_100th[alive_before]
286             + self.life_level_gain_100th
287             - nb_hits[alive_before] * 100
288         ).clamp(max=self.life_level_max * 100 + 99)
289
290         alive_after = self.life_level_in_100th >= 100
291
292         self.worlds[torch.logical_not(alive_after)] = self.tile2id["#"]
293
294         reward = nb_hits * self.reward_per_hit
295
296         for i in range(q.size(0)):
297             if q[i] == 1:
298                 reward[i] += self.object_reward[self.accessible_object[i].item()]
299                 self.accessible_object[i] = self.next_object[
300                     self.accessible_object[i].item()
301                 ]
302
303         reward = (
304             alive_after.long() * reward
305             + alive_before.long() * (1 - alive_after.long()) * self.reward_death
306         )
307         inventory = torch.tensor(
308             [
309                 self.accessible_object_to_inventory[s.item()]
310                 for s in self.accessible_object
311             ]
312         )
313
314         self.life_level_in_100th = (
315             self.life_level_in_100th
316             * (self.accessible_object != self.tile2id["-"]).long()
317         )
318
319         reward[torch.logical_not(alive_before)] = 0
320
321         return reward, inventory, self.life_level_in_100th // 100
322
323     def monster_moves(self):
324         # Current positions of the monsters
325         m = (self.worlds == self.tile2id["$"]).long().flatten(1)
326
327         # Total number of monsters
328         n = m.sum(-1).max()
329
330         # Create a tensor with one channel per monster
331         r = (
332             (torch.rand(m.size(), device=m.device) * m)
333             .sort(dim=-1, descending=True)
334             .indices[:, :n]
335         )
336         o = m.new_zeros((m.size(0), n) + m.size()[1:])
337         i = torch.arange(o.size(0), device=o.device)[:, None].expand(-1, o.size(1))
338         j = torch.arange(o.size(1), device=o.device)[None, :].expand(o.size(0), -1)
339         o[i, j, r] = 1
340         o = o * m[:, None]
341
342         # Create the tensor of possible motions
343         o = o.view((self.worlds.size(0), n) + self.worlds.flatten(1).size()[1:])
344         move_kernel = torch.tensor(
345             [[[[0.0, 1.0, 0.0], [1.0, 1.0, 1.0], [0.0, 1.0, 0.0]]]], device=o.device
346         )
347
348         p = (
349             conv2d(
350                 o.view(
351                     o.size(0) * o.size(1), 1, self.worlds.size(-2), self.worlds.size(-1)
352                 ).float(),
353                 move_kernel,
354                 padding=1,
355             ).view(o.size())
356             == 1.0
357         ).long()
358
359         # Let's do the moves per say
360         i = torch.arange(self.worlds.size(0), device=self.worlds.device)[
361             :, None
362         ].expand_as(r)
363
364         for n in range(p.size(1)):
365             u = o[:, n].sort(dim=-1, descending=True).indices[:, :1]
366             q = p[:, n] * (self.worlds.flatten(1) == self.tile2id[" "]) + o[:, n]
367             r = (
368                 (q * torch.rand(q.size(), device=q.device))
369                 .sort(dim=-1, descending=True)
370                 .indices[:, :1]
371             )
372             self.worlds.flatten(1)[i, u] = self.tile2id[" "]
373             self.worlds.flatten(1)[i, r] = self.tile2id["$"]
374
375         nb_hits = (
376             (
377                 conv2d(
378                     (self.worlds == self.tile2id["$"]).float()[:, None],
379                     move_kernel,
380                     padding=1,
381                 )
382                 .long()
383                 .squeeze(1)
384                 * (self.worlds == self.tile2id["@"]).long()
385             )
386             .flatten(1)
387             .sum(-1)
388         )
389
390         return nb_hits
391
392     def state_size(self):
393         return (self.view_height + 1) * self.view_width
394
395     def state(self):
396         i_height, i_width = (
397             self.view_height - 2 * self.world_margin,
398             self.view_width - 2 * self.world_margin,
399         )
400         a = (self.worlds == self.tile2id["@"]).nonzero()
401         y = i_height * ((a[:, 1] - self.world_margin) // i_height)
402         x = i_width * ((a[:, 2] - self.world_margin) // i_width)
403         n = a[:, 0][:, None, None].expand(-1, self.view_height, self.view_width)
404         i = (
405             torch.arange(self.view_height, device=a.device)[None, :, None]
406             + y[:, None, None]
407         ).expand_as(n)
408         j = (
409             torch.arange(self.view_width, device=a.device)[None, None, :]
410             + x[:, None, None]
411         ).expand_as(n)
412         v = self.worlds.new_full(
413             (self.worlds.size(0), self.view_height + 1, self.view_width),
414             self.tile2id["#"],
415         )
416
417         v[a[:, 0], : self.view_height] = self.worlds[n, i, j]
418
419         v[:, self.view_height] = self.tile2id["-"]
420         v[:, self.view_height, 0] = self.tile2id["0"] + (
421             self.life_level_in_100th // 100
422         ).clamp(min=0, max=self.life_level_max)
423         v[:, self.view_height, 1] = torch.tensor(
424             [
425                 self.accessible_object_to_inventory[o.item()]
426                 for o in self.accessible_object
427             ],
428             device=v.device,
429         )
430
431         return v.flatten(1), self.life_level_in_100th >= 100
432
433     def state2str(self, t, width=None):
434         def tile(n):
435             n = n.item()
436             if n in self.id2tile:
437                 return self.id2tile[n]
438             else:
439                 return "?"
440
441         if t.dim() == 2:
442             return [self.state2str(r, width) for r in t]
443
444         if width is None:
445             width = self.view_width
446
447         t = t.reshape(-1, width)
448
449         t = "\n".join(["".join([tile(n) for n in r]) for r in t])
450
451         return t
452
453
454 ######################################################################
455
456 if __name__ == "__main__":
457     import os, time, sys
458
459     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
460
461     # char_conv = lambda x: x
462     char_conv = to_unicode
463
464     # nb_agents, nb_iter, display = 1000, 1000, False
465     # ansi_term = False
466
467     nb_agents, nb_iter, display = 4, 10000, True
468     ansi_term = True
469
470     if ansi_term:
471         char_conv = lambda x: to_ansi(to_unicode(x))
472
473     start_time = time.perf_counter()
474     environment = PicroCrafterEnvironment(
475         world_height=27,
476         world_width=27,
477         nb_walls=35,
478         view_height=9,
479         view_width=9,
480         world_margin=4,
481         device=device,
482     )
483
484     environment.reset(nb_agents)
485
486     print(f"timing {nb_agents/(time.perf_counter() - start_time)} init per s")
487
488     start_time = time.perf_counter()
489
490     stop = 0
491     for k in range(nb_iter):
492         if display:
493             if ansi_term:
494                 to_print = "\u001bc"
495                 # print("\u001b[2J")
496             else:
497                 to_print = ""
498                 os.system("clear")
499
500             l = environment.state2str(
501                 environment.worlds.flatten(1), width=environment.world_width
502             )
503
504             to_print += char_conv(fusion_multi_lines(l)) + "\n\n"
505
506         state, alive = environment.state()
507         action = alive * torch.randint(
508             environment.nb_actions(), (nb_agents,), device=device
509         )
510
511         rewards, inventories, life_levels = environment.step(action)
512
513         if display:
514             l = environment.state2str(state)
515             l = [
516                 v + f"\n{environment.action2str(a.item())}/{r: 3d}"
517                 for (v, a, r) in zip(l, action, rewards)
518             ]
519
520             to_print += (
521                 char_conv(fusion_multi_lines(l, width_min=environment.world_width))
522                 + "\n"
523             )
524
525             print(to_print)
526             sys.stdout.flush()
527             time.sleep(0.25)
528
529         if (life_levels > 0).long().sum() == 0:
530             stop += 1
531             if stop == 10:
532                 break
533
534     print(f"timing {(nb_agents*k)/(time.perf_counter() - start_time)} iteration per s")