Update.
[pytorch.git] / picocrafter.py
1 #!/usr/bin/env python
2
3 #########################################################################
4 # This program is free software: you can redistribute it and/or modify  #
5 # it under the terms of the version 3 of the GNU General Public License #
6 # as published by the Free Software Foundation.                         #
7 #                                                                       #
8 # This program is distributed in the hope that it will be useful, but   #
9 # WITHOUT ANY WARRANTY; without even the implied warranty of            #
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      #
11 # General Public License for more details.                              #
12 #                                                                       #
13 # You should have received a copy of the GNU General Public License     #
14 # along with this program. If not, see <http://www.gnu.org/licenses/>.  #
15 #                                                                       #
16 # Written by and Copyright (C) Francois Fleuret                         #
17 # Contact <francois.fleuret@unige.ch> for comments & bug reports        #
18 #########################################################################
19
20 # This is a tiny rogue-like environment implemented with tensor
21 # operations, that runs in batches efficiently on a GPU. On a RTX4090
22 # it can initialize ~20k environments per second and run ~40k
23 # iterations.
24 #
25 # The agent "@" moves in a maze-like grid with random walls "#". There
26 # are five actions: move NESW or do not move.
27 #
28 # There are monsters "$" moving randomly. The agent gets hit by every
29 # monster present in one of the 4 direct neighborhoods at the end of
30 # the moves, each hit results in a rewards of -1.
31 #
32 # The agent starts with 5 life points, each hit costs it 1pt, when it
33 # gets to 0 it dies, gets a reward of -10 and the episode is over. At
34 # every step it recovers 1/20th of a life point, with a maximum of
35 # 5pt.
36 #
37 # The agent can carry "keys" ("a", "b", "c") that open "vaults" ("A",
38 # "B", "C"). The keys and vault can only be used in sequence:
39 # initially the agent can move only to free spaces, or to the "a", in
40 # which case the key is removed from the environment and the agent now
41 # carries it, and can move to free spaces or the "A". When it moves to
42 # the "A", it gets a reward, loses the "a", the "A" is removed from
43 # the environment, but can now move to the "b", etc. Rewards are 1 for
44 # "A" and "B" and 10 for "C".
45
46 ######################################################################
47
48 import torch
49
50 from torch.nn.functional import conv2d
51
52 ######################################################################
53
54
55 class PicroCrafterEngine:
56     def __init__(
57         self,
58         world_height=27,
59         world_width=27,
60         nb_walls=27,
61         margin=2,
62         view_height=5,
63         view_width=5,
64         device=torch.device("cpu"),
65     ):
66         assert (world_height - 2 * margin) % (view_height - 2 * margin) == 0
67         assert (world_width - 2 * margin) % (view_width - 2 * margin) == 0
68
69         self.device = device
70
71         self.world_height = world_height
72         self.world_width = world_width
73         self.margin = margin
74         self.view_height = view_height
75         self.view_width = view_width
76         self.nb_walls = nb_walls
77         self.life_level_max = 5
78         self.life_level_gain_100th = 5
79         self.reward_per_hit = -1
80         self.reward_death = -10
81
82         self.tokens = " +#@$aAbBcC."
83         self.token2id = dict([(t, n) for n, t in enumerate(self.tokens)])
84         self.id2token = dict([(n, t) for n, t in enumerate(self.tokens)])
85
86         self.next_object = dict(
87             [
88                 (self.token2id[s], self.token2id[t])
89                 for (s, t) in [
90                     ("a", "A"),
91                     ("A", "b"),
92                     ("b", "B"),
93                     ("B", "c"),
94                     ("c", "C"),
95                     ("C", "."),
96                 ]
97             ]
98         )
99
100         self.object_reward = dict(
101             [
102                 (self.token2id[t], r)
103                 for (t, r) in [
104                     ("a", 0),
105                     ("A", 1),
106                     ("b", 0),
107                     ("B", 1),
108                     ("c", 0),
109                     ("C", 10),
110                 ]
111             ]
112         )
113
114         self.accessible_object_to_inventory = dict(
115             [
116                 (self.token2id[s], self.token2id[t])
117                 for (s, t) in [
118                     ("a", " "),
119                     ("A", "a"),
120                     ("b", " "),
121                     ("B", "b"),
122                     ("c", " "),
123                     ("C", "c"),
124                     (".", " "),
125                 ]
126             ]
127         )
128
129     def reset(self, nb_agents):
130         self.worlds = self.create_worlds(
131             nb_agents, self.world_height, self.world_width, self.nb_walls, self.margin
132         ).to(self.device)
133         self.life_level_in_100th = torch.full(
134             (nb_agents,), self.life_level_max * 100, device=self.device
135         )
136         self.accessible_object = torch.full(
137             (nb_agents,), self.token2id["a"], device=self.device
138         )
139
140     def create_mazes(self, nb, height, width, nb_walls):
141         m = torch.zeros(nb, height, width, dtype=torch.int64, device=self.device)
142         m[:, 0, :] = 1
143         m[:, -1, :] = 1
144         m[:, :, 0] = 1
145         m[:, :, -1] = 1
146
147         i = torch.arange(height, device=m.device)[None, :, None]
148         j = torch.arange(width, device=m.device)[None, None, :]
149
150         for _ in range(nb_walls):
151             q = torch.rand(m.size(), device=m.device).flatten(1).sort(-1).indices * (
152                 (1 - m) * (i % 2 == 0) * (j % 2 == 0)
153             ).flatten(1)
154             q = (q == q.max(dim=-1, keepdim=True).values).long().view(m.size())
155             a = q[:, None].expand(-1, 4, -1, -1).clone()
156             a[:, 0, :-1, :] += q[:, 1:, :]
157             a[:, 0, :-2, :] += q[:, 2:, :]
158             a[:, 1, 1:, :] += q[:, :-1, :]
159             a[:, 1, 2:, :] += q[:, :-2, :]
160             a[:, 2, :, :-1] += q[:, :, 1:]
161             a[:, 2, :, :-2] += q[:, :, 2:]
162             a[:, 3, :, 1:] += q[:, :, :-1]
163             a[:, 3, :, 2:] += q[:, :, :-2]
164             a = a[
165                 torch.arange(a.size(0), device=a.device),
166                 torch.randint(4, (a.size(0),), device=a.device),
167             ]
168             m = (m + q + a).clamp(max=1)
169
170         return m
171
172     def create_worlds(self, nb, height, width, nb_walls, margin=2):
173         margin -= 1  # The maze adds a wall all around
174         m = self.create_mazes(nb, height - 2 * margin, width - 2 * margin, nb_walls)
175         q = m.flatten(1)
176         z = "@aAbBcC$$$$$"  # What to add to the maze
177         u = torch.rand(q.size(), device=q.device) * (1 - q)
178         r = u.sort(dim=-1, descending=True).indices[:, : len(z)]
179
180         q *= self.token2id["#"]
181         q[
182             torch.arange(q.size(0), device=q.device)[:, None].expand_as(r), r
183         ] = torch.tensor([self.token2id[c] for c in z], device=q.device)[None, :]
184
185         if margin > 0:
186             r = m.new_full(
187                 (m.size(0), m.size(1) + margin * 2, m.size(2) + margin * 2),
188                 self.token2id["+"],
189             )
190             r[:, margin:-margin, margin:-margin] = m
191             m = r
192         return m
193
194     def nb_actions(self):
195         return 5
196
197     def nb_view_tokens(self):
198         return len(self.tokens)
199
200     def min_max_reward(self):
201         return (
202             min(4 * self.reward_per_hit, self.reward_death),
203             max(self.object_reward.values()),
204         )
205
206     def step(self, actions):
207         a = (self.worlds == self.token2id["@"]).nonzero()
208         self.worlds[a[:, 0], a[:, 1], a[:, 2]] = self.token2id[" "]
209         s = torch.tensor([[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]], device=self.device)
210         b = a.clone()
211         b[:, 1:] = b[:, 1:] + s[actions[b[:, 0]]]
212         # position is empty
213         o = (self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.token2id[" "]).long()
214         # or it is the next accessible object
215         q = (
216             self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.accessible_object[b[:, 0]]
217         ).long()
218         o = (o + q).clamp(max=1)[:, None]
219         b = (1 - o) * a + o * b
220         self.worlds[b[:, 0], b[:, 1], b[:, 2]] = self.token2id["@"]
221
222         qq = q
223         q = qq.new_zeros((self.worlds.size(0),) + qq.size()[1:])
224         q[b[:, 0]] = qq
225
226         nb_hits = self.monster_moves()
227
228         alive_before = self.life_level_in_100th > 0
229         self.life_level_in_100th[alive_before] = (
230             self.life_level_in_100th[alive_before]
231             + self.life_level_gain_100th
232             - nb_hits[alive_before] * 100
233         ).clamp(max=self.life_level_max * 100)
234         alive_after = self.life_level_in_100th > 0
235         self.worlds[torch.logical_not(alive_after)] = self.token2id["#"]
236         reward = nb_hits * self.reward_per_hit
237
238         for i in range(q.size(0)):
239             if q[i] == 1:
240                 reward[i] += self.object_reward[self.accessible_object[i].item()]
241                 self.accessible_object[i] = self.next_object[
242                     self.accessible_object[i].item()
243                 ]
244
245         reward = (
246             reward + alive_before.long() * (1 - alive_after.long()) * self.reward_death
247         )
248         inventory = torch.tensor(
249             [
250                 self.accessible_object_to_inventory[s.item()]
251                 for s in self.accessible_object
252             ]
253         )
254
255         self.life_level_in_100th = (
256             self.life_level_in_100th
257             * (self.accessible_object != self.token2id["."]).long()
258         )
259
260         reward[torch.logical_not(alive_before)] = 0
261         return reward, inventory, self.life_level_in_100th // 100
262
263     def monster_moves(self):
264         # Current positions of the monsters
265         m = (self.worlds == self.token2id["$"]).long().flatten(1)
266
267         # Total number of monsters
268         n = m.sum(-1).max()
269
270         # Create a tensor with one channel per monster
271         r = (
272             (torch.rand(m.size(), device=m.device) * m)
273             .sort(dim=-1, descending=True)
274             .indices[:, :n]
275         )
276         o = m.new_zeros((m.size(0), n) + m.size()[1:])
277         i = torch.arange(o.size(0), device=o.device)[:, None].expand(-1, o.size(1))
278         j = torch.arange(o.size(1), device=o.device)[None, :].expand(o.size(0), -1)
279         o[i, j, r] = 1
280         o = o * m[:, None]
281
282         # Create the tensor of possible motions
283         o = o.view((self.worlds.size(0), n) + self.worlds.flatten(1).size()[1:])
284         move_kernel = torch.tensor(
285             [[[[0.0, 1.0, 0.0], [1.0, 1.0, 1.0], [0.0, 1.0, 0.0]]]], device=o.device
286         )
287
288         p = (
289             conv2d(
290                 o.view(
291                     o.size(0) * o.size(1), 1, self.worlds.size(-2), self.worlds.size(-1)
292                 ).float(),
293                 move_kernel,
294                 padding=1,
295             ).view(o.size())
296             == 1.0
297         ).long()
298
299         # Let's do the moves per say
300         i = torch.arange(self.worlds.size(0), device=self.worlds.device)[
301             :, None
302         ].expand_as(r)
303
304         for n in range(p.size(1)):
305             u = o[:, n].sort(dim=-1, descending=True).indices[:, :1]
306             q = p[:, n] * (self.worlds.flatten(1) == self.token2id[" "]) + o[:, n]
307             r = (
308                 (q * torch.rand(q.size(), device=q.device))
309                 .sort(dim=-1, descending=True)
310                 .indices[:, :1]
311             )
312             self.worlds.flatten(1)[i, u] = self.token2id[" "]
313             self.worlds.flatten(1)[i, r] = self.token2id["$"]
314
315         nb_hits = (
316             (
317                 conv2d(
318                     (self.worlds == self.token2id["$"]).float()[:, None],
319                     move_kernel,
320                     padding=1,
321                 )
322                 .long()
323                 .squeeze(1)
324                 * (self.worlds == self.token2id["@"]).long()
325             )
326             .flatten(1)
327             .sum(-1)
328         )
329
330         return nb_hits
331
332     def views(self):
333         i_height, i_width = (
334             self.view_height - 2 * self.margin,
335             self.view_width - 2 * self.margin,
336         )
337         a = (self.worlds == self.token2id["@"]).nonzero()
338         y = i_height * ((a[:, 1] - self.margin) // i_height)
339         x = i_width * ((a[:, 2] - self.margin) // i_width)
340         n = a[:, 0][:, None, None].expand(-1, self.view_height, self.view_width)
341         i = (
342             torch.arange(self.view_height, device=a.device)[None, :, None]
343             + y[:, None, None]
344         ).expand_as(n)
345         j = (
346             torch.arange(self.view_width, device=a.device)[None, None, :]
347             + x[:, None, None]
348         ).expand_as(n)
349         v = self.worlds.new_full(
350             (self.worlds.size(0), self.view_height, self.view_width), self.token2id["#"]
351         )
352
353         v[a[:, 0]] = self.worlds[n, i, j]
354
355         return v
356
357     def print_worlds(
358         self, src=None, comments=[], width=None, printer=print, ansi_term=False
359     ):
360         if src is None:
361             src = self.worlds
362
363         if width is None:
364             width = src.size(2)
365
366         def token(n):
367             n = n.item()
368             if n in self.id2token:
369                 return self.id2token[n]
370             else:
371                 return "?"
372
373         for k in range(src.size(1)):
374             s = ["".join([token(n) for n in m[k]]) for m in src]
375             s = [r + " " * (width - len(r)) for r in s]
376             if ansi_term:
377
378                 def colorize(x):
379                     for u, c in [("#", 40), ("$", 31), ("@", 32)] + [
380                         (x, 36) for x in "aAbBcC"
381                     ]:
382                         x = x.replace(u, f"\u001b[{c}m{u}\u001b[0m")
383                     return x
384
385                 s = [colorize(x) for x in s]
386             printer(" | ".join(s))
387
388         s = [c + " " * (width - len(c)) for c in comments]
389         printer(" | ".join(s))
390
391
392 ######################################################################
393
394 if __name__ == "__main__":
395     import os, time
396
397     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
398
399     # ansi_term = False
400     # nb_agents, nb_iter, display = 1000, 1000, False
401     nb_agents, nb_iter, display = 3, 10000, True
402     ansi_term = True
403
404     start_time = time.perf_counter()
405     engine = PicroCrafterEngine(
406         world_height=27,
407         world_width=27,
408         nb_walls=35,
409         # world_height=15,
410         # world_width=15,
411         # nb_walls=0,
412         view_height=9,
413         view_width=9,
414         margin=4,
415         device=device,
416     )
417
418     engine.reset(nb_agents)
419
420     print(f"timing {nb_agents/(time.perf_counter() - start_time)} init per s")
421
422     start_time = time.perf_counter()
423
424     stop = 0
425     for k in range(nb_iter):
426         action = torch.randint(engine.nb_actions(), (nb_agents,), device=device)
427         rewards, inventories, life_levels = engine.step(
428             torch.randint(engine.nb_actions(), (nb_agents,), device=device)
429         )
430
431         if display:
432             os.system("clear")
433             engine.print_worlds(
434                 ansi_term=ansi_term,
435             )
436             print()
437             engine.print_worlds(
438                 src=engine.views(),
439                 comments=[
440                     f"L{p}I{engine.id2token[s.item()]}R{r}"
441                     for p, s, r in zip(life_levels, inventories, rewards)
442                 ],
443                 width=engine.world_width,
444                 ansi_term=ansi_term,
445             )
446             time.sleep(0.25)
447
448         if (life_levels > 0).long().sum() == 0:
449             stop += 1
450             if stop == 2:
451                 break
452
453     print(
454         f"timing {(nb_agents*nb_iter)/(time.perf_counter() - start_time)} iteration per s"
455     )