X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=picocrafter.py;h=23d93b2cdd0d2c9a52bd76a6289dcf0e2d5df135;hb=HEAD;hp=33a00c126e5309924933719fcf0a32bdcdffaf2a;hpb=a2ccdd2f5e9fb3e7ed52492729b880f815ddfbcb;p=pytorch.git diff --git a/picocrafter.py b/picocrafter.py index 33a00c1..23d93b2 100755 --- a/picocrafter.py +++ b/picocrafter.py @@ -22,8 +22,9 @@ # it can initialize ~20k environments per second and run ~40k # iterations. # -# The agent "@" moves in a maze-like grid with random walls "#". There -# are five actions: move NESW or do not move. +# The environment is a rectangular area with walls "#" dispatched +# randomly. The agent "@" can perform five actions: move "NESW" or be +# immobile "I". # # There are monsters "$" moving randomly. The agent gets hit by every # monster present in one of the 4 direct neighborhoods at the end of @@ -35,11 +36,14 @@ # 5pt. # # The agent can carry "keys" ("a", "b", "c") that open "vaults" ("A", -# "B", "C"). They keys can only be used in sequence: initially the -# agent can move only to free spaces, or to the "a", in which case it -# now carries it, and can move to free spaces or the "A". When it -# moves to the "A", it gets a reward and loses the "a", but can now -# move to the "b", etc. Rewards are 1 for "A" and "B" and 10 for "C". +# "B", "C"). The keys and vault can only be used in sequence: +# initially the agent can move only to free spaces, or to the "a", in +# which case the key is removed from the environment and the agent now +# carries it, it appears in the inventory at the bottom of the frame, +# and the agent can now move to free spaces or the "A". When it moves +# to the "A", it gets a reward, loses the "a", the "A" is removed from +# the environment, but the agent can now move to the "b", etc. Rewards +# are 1 for "A" and "B" and 10 for "C". ###################################################################### @@ -50,25 +54,64 @@ from torch.nn.functional import conv2d ###################################################################### -class PicroCrafterEngine: +def to_ansi(s): + if type(s) is list: + return [to_ansi(x) for x in s] + + for u, c in [("$", 31), ("@", 32)] + [(x, 36) for x in "aAbBcC"]: + s = s.replace(u, f"\u001b[{c}m{u}\u001b[0m") + + return s + + +def to_unicode(s): + if type(s) is list: + return [to_unicode(x) for x in s] + + for u, c in [("#", "█"), ("+", "░"), ("|", "│")]: + s = s.replace(u, c) + + return s + + +def fusion_multi_lines(l, width_min=0): + l = [x if type(x) is str else str(x) for x in l] + + l = [x.split("\n") for x in l] + + def center(r, w): + k = w - len(r) + return " " * (k // 2) + r + " " * (k - k // 2) + + def f(o, h): + w = max(width_min, max([len(r) for r in o])) + return [" " * w] * (h - len(o)) + [center(r, w) for r in o] + + h = max([len(x) for x in l]) + l = [f(o, h) for o in l] + + return "\n".join(["|".join([o[k] for o in l]) for k in range(h)]) + + +class PicroCrafterEnvironment: def __init__( self, world_height=27, world_width=27, nb_walls=27, - margin=2, + world_margin=2, view_height=5, view_width=5, device=torch.device("cpu"), ): - assert (world_height - 2 * margin) % (view_height - 2 * margin) == 0 - assert (world_width - 2 * margin) % (view_width - 2 * margin) == 0 + assert (world_height - 2 * world_margin) % (view_height - 2 * world_margin) == 0 + assert (world_width - 2 * world_margin) % (view_width - 2 * world_margin) == 0 self.device = device self.world_height = world_height self.world_width = world_width - self.margin = margin + self.world_margin = world_margin self.view_height = view_height self.view_width = view_width self.nb_walls = nb_walls @@ -77,26 +120,29 @@ class PicroCrafterEngine: self.reward_per_hit = -1 self.reward_death = -10 - self.tokens = " +#@$aAbBcC" - self.token2id = dict([(t, n) for n, t in enumerate(self.tokens)]) - self.id2token = dict([(n, t) for n, t in enumerate(self.tokens)]) + self.tiles = " +#@$aAbBcC-" + "".join( + [str(n) for n in range(self.life_level_max + 1)] + ) + self.tile2id = dict([(t, n) for n, t in enumerate(self.tiles)]) + self.id2tile = dict([(n, t) for n, t in enumerate(self.tiles)]) self.next_object = dict( [ - (self.token2id[s], self.token2id[t]) + (self.tile2id[s], self.tile2id[t]) for (s, t) in [ ("a", "A"), ("A", "b"), ("b", "B"), ("B", "c"), ("c", "C"), + ("C", "-"), ] ] ) self.object_reward = dict( [ - (self.token2id[t], r) + (self.tile2id[t], r) for (t, r) in [ ("a", 0), ("A", 1), @@ -108,29 +154,34 @@ class PicroCrafterEngine: ] ) - self.acessible_object_to_inventory = dict( + self.accessible_object_to_inventory = dict( [ - (self.token2id[s], self.token2id[t]) + (self.tile2id[s], self.tile2id[t]) for (s, t) in [ ("a", " "), ("A", "a"), ("b", " "), ("B", "b"), ("c", " "), - ("C", " "), + ("C", "c"), + ("-", " "), ] ] ) def reset(self, nb_agents): self.worlds = self.create_worlds( - nb_agents, self.world_height, self.world_width, self.nb_walls, self.margin + nb_agents, + self.world_height, + self.world_width, + self.nb_walls, + self.world_margin, ).to(self.device) self.life_level_in_100th = torch.full( - (nb_agents,), self.life_level_max * 100, device=self.device + (nb_agents,), self.life_level_max * 100 + 99, device=self.device ) self.accessible_object = torch.full( - (nb_agents,), self.token2id["a"], device=self.device + (nb_agents,), self.tile2id["a"], device=self.device ) def create_mazes(self, nb, height, width, nb_walls): @@ -165,33 +216,41 @@ class PicroCrafterEngine: return m - def create_worlds(self, nb, height, width, nb_walls, margin=2): - margin -= 1 # The maze adds a wall all around - m = self.create_mazes(nb, height - 2 * margin, width - 2 * margin, nb_walls) + def create_worlds(self, nb, height, width, nb_walls, world_margin=2): + world_margin -= 1 # The maze adds a wall all around + m = self.create_mazes( + nb, height - 2 * world_margin, width - 2 * world_margin, nb_walls + ) q = m.flatten(1) z = "@aAbBcC$$$$$" # What to add to the maze u = torch.rand(q.size(), device=q.device) * (1 - q) r = u.sort(dim=-1, descending=True).indices[:, : len(z)] - q *= self.token2id["#"] + q *= self.tile2id["#"] q[ torch.arange(q.size(0), device=q.device)[:, None].expand_as(r), r - ] = torch.tensor([self.token2id[c] for c in z], device=q.device)[None, :] + ] = torch.tensor([self.tile2id[c] for c in z], device=q.device)[None, :] - if margin > 0: + if world_margin > 0: r = m.new_full( - (m.size(0), m.size(1) + margin * 2, m.size(2) + margin * 2), - self.token2id["+"], + (m.size(0), m.size(1) + world_margin * 2, m.size(2) + world_margin * 2), + self.tile2id["+"], ) - r[:, margin:-margin, margin:-margin] = m + r[:, world_margin:-world_margin, world_margin:-world_margin] = m m = r return m def nb_actions(self): return 5 - def nb_view_tokens(self): - return len(self.tokens) + def action2str(self, n): + if n >= 0 and n < 5: + return "INESW"[n] + else: + return "?" + + def nb_state_token_values(self): + return len(self.tiles) def min_max_reward(self): return ( @@ -200,32 +259,39 @@ class PicroCrafterEngine: ) def step(self, actions): - a = (self.worlds == self.token2id["@"]).nonzero() - self.worlds[a[:, 0], a[:, 1], a[:, 2]] = self.token2id[" "] + a = (self.worlds == self.tile2id["@"]).nonzero() + self.worlds[a[:, 0], a[:, 1], a[:, 2]] = self.tile2id[" "] s = torch.tensor([[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]], device=self.device) b = a.clone() b[:, 1:] = b[:, 1:] + s[actions[b[:, 0]]] - # position is empty - o = (self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.token2id[" "]).long() + o = (self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.tile2id[" "]).long() # or it is the next accessible object q = ( self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.accessible_object[b[:, 0]] ).long() o = (o + q).clamp(max=1)[:, None] b = (1 - o) * a + o * b - self.worlds[b[:, 0], b[:, 1], b[:, 2]] = self.token2id["@"] + self.worlds[b[:, 0], b[:, 1], b[:, 2]] = self.tile2id["@"] + + qq = q + q = qq.new_zeros((self.worlds.size(0),) + qq.size()[1:]) + q[b[:, 0]] = qq nb_hits = self.monster_moves() - alive_before = self.life_level_in_100th > 0 + alive_before = self.life_level_in_100th >= 100 + self.life_level_in_100th[alive_before] = ( self.life_level_in_100th[alive_before] + self.life_level_gain_100th - nb_hits[alive_before] * 100 - ).clamp(max=self.life_level_max * 100) - alive_after = self.life_level_in_100th > 0 - self.worlds[torch.logical_not(alive_after)] = self.token2id["#"] + ).clamp(max=self.life_level_max * 100 + 99) + + alive_after = self.life_level_in_100th >= 100 + + self.worlds[torch.logical_not(alive_after)] = self.tile2id["#"] + reward = nb_hits * self.reward_per_hit for i in range(q.size(0)): @@ -236,21 +302,28 @@ class PicroCrafterEngine: ] reward = ( - reward + alive_before.long() * (1 - alive_after.long()) * self.reward_death + alive_after.long() * reward + + alive_before.long() * (1 - alive_after.long()) * self.reward_death ) inventory = torch.tensor( [ - self.acessible_object_to_inventory[s.item()] + self.accessible_object_to_inventory[s.item()] for s in self.accessible_object ] ) + self.life_level_in_100th = ( + self.life_level_in_100th + * (self.accessible_object != self.tile2id["-"]).long() + ) + reward[torch.logical_not(alive_before)] = 0 + return reward, inventory, self.life_level_in_100th // 100 def monster_moves(self): # Current positions of the monsters - m = (self.worlds == self.token2id["$"]).long().flatten(1) + m = (self.worlds == self.tile2id["$"]).long().flatten(1) # Total number of monsters n = m.sum(-1).max() @@ -291,25 +364,25 @@ class PicroCrafterEngine: for n in range(p.size(1)): u = o[:, n].sort(dim=-1, descending=True).indices[:, :1] - q = p[:, n] * (self.worlds.flatten(1) == self.token2id[" "]) + o[:, n] + q = p[:, n] * (self.worlds.flatten(1) == self.tile2id[" "]) + o[:, n] r = ( (q * torch.rand(q.size(), device=q.device)) .sort(dim=-1, descending=True) .indices[:, :1] ) - self.worlds.flatten(1)[i, u] = self.token2id[" "] - self.worlds.flatten(1)[i, r] = self.token2id["$"] + self.worlds.flatten(1)[i, u] = self.tile2id[" "] + self.worlds.flatten(1)[i, r] = self.tile2id["$"] nb_hits = ( ( conv2d( - (self.worlds == self.token2id["$"]).float()[:, None], + (self.worlds == self.tile2id["$"]).float()[:, None], move_kernel, padding=1, ) .long() .squeeze(1) - * (self.worlds == self.token2id["@"]).long() + * (self.worlds == self.tile2id["@"]).long() ) .flatten(1) .sum(-1) @@ -317,14 +390,17 @@ class PicroCrafterEngine: return nb_hits - def views(self): + def state_size(self): + return (self.view_height + 1) * self.view_width + + def state(self): i_height, i_width = ( - self.view_height - 2 * self.margin, - self.view_width - 2 * self.margin, + self.view_height - 2 * self.world_margin, + self.view_width - 2 * self.world_margin, ) - a = (self.worlds == self.token2id["@"]).nonzero() - y = i_height * ((a[:, 1] - self.margin) // i_height) - x = i_width * ((a[:, 2] - self.margin) // i_width) + a = (self.worlds == self.tile2id["@"]).nonzero() + y = i_height * ((a[:, 1] - self.world_margin) // i_height) + x = i_width * ((a[:, 2] - self.world_margin) // i_width) n = a[:, 0][:, None, None].expand(-1, self.view_height, self.view_width) i = ( torch.arange(self.view_height, device=a.device)[None, :, None] @@ -335,103 +411,125 @@ class PicroCrafterEngine: + x[:, None, None] ).expand_as(n) v = self.worlds.new_full( - (self.worlds.size(0), self.view_height, self.view_width), self.token2id["#"] + (self.worlds.size(0), self.view_height + 1, self.view_width), + self.tile2id["#"], ) - v[a[:, 0]] = self.worlds[n, i, j] + v[a[:, 0], : self.view_height] = self.worlds[n, i, j] - return v - - def print_worlds( - self, src=None, comments=[], width=None, printer=print, ansi_term=False - ): - if src is None: - src = self.worlds + v[:, self.view_height] = self.tile2id["-"] + v[:, self.view_height, 0] = self.tile2id["0"] + ( + self.life_level_in_100th // 100 + ).clamp(min=0, max=self.life_level_max) + v[:, self.view_height, 1] = torch.tensor( + [ + self.accessible_object_to_inventory[o.item()] + for o in self.accessible_object + ], + device=v.device, + ) - if width is None: - width = src.size(2) + return v.flatten(1), self.life_level_in_100th >= 100 - def token(n): + def state2str(self, t, width=None): + def tile(n): n = n.item() - if n in self.id2token: - return self.id2token[n] + if n in self.id2tile: + return self.id2tile[n] else: return "?" - for k in range(src.size(1)): - s = ["".join([token(n) for n in m[k]]) for m in src] - s = [r + " " * (width - len(r)) for r in s] - if ansi_term: + if t.dim() == 2: + return [self.state2str(r, width) for r in t] + + if width is None: + width = self.view_width - def colorize(x): - for u, c in [("#", 40), ("$", 31), ("@", 32)] + [ - (x, 36) for x in "aAbBcC" - ]: - x = x.replace(u, f"\u001b[{c}m{u}\u001b[0m") - return x + t = t.reshape(-1, width) - s = [colorize(x) for x in s] - printer(" | ".join(s)) + t = "\n".join(["".join([tile(n) for n in r]) for r in t]) - s = [c + " " * (width - len(c)) for c in comments] - printer(" | ".join(s)) + return t ###################################################################### if __name__ == "__main__": - import os, time + import os, time, sys device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - ansi_term = False - # nb_agents, nb_iter, display = 1000, 100, False - nb_agents, nb_iter, display = 3, 10000, True + # char_conv = lambda x: x + char_conv = to_unicode + + # nb_agents, nb_iter, display = 1000, 1000, False + # ansi_term = False + + nb_agents, nb_iter, display = 4, 10000, True ansi_term = True + if ansi_term: + char_conv = lambda x: to_ansi(to_unicode(x)) + start_time = time.perf_counter() - engine = PicroCrafterEngine( + environment = PicroCrafterEnvironment( world_height=27, world_width=27, nb_walls=35, view_height=9, view_width=9, - margin=4, + world_margin=4, device=device, ) - engine.reset(nb_agents) + environment.reset(nb_agents) print(f"timing {nb_agents/(time.perf_counter() - start_time)} init per s") start_time = time.perf_counter() + stop = 0 for k in range(nb_iter): - action = torch.randint(engine.nb_actions(), (nb_agents,), device=device) - rewards, inventories, life_levels = engine.step( - torch.randint(engine.nb_actions(), (nb_agents,), device=device) + if display: + if ansi_term: + to_print = "\u001bc" + # print("\u001b[2J") + else: + to_print = "" + os.system("clear") + + l = environment.state2str( + environment.worlds.flatten(1), width=environment.world_width + ) + + to_print += char_conv(fusion_multi_lines(l)) + "\n\n" + + state, alive = environment.state() + action = alive * torch.randint( + environment.nb_actions(), (nb_agents,), device=device ) + rewards, inventories, life_levels = environment.step(action) + if display: - os.system("clear") - engine.print_worlds( - ansi_term=ansi_term, - ) - print() - engine.print_worlds( - src=engine.views(), - comments=[ - f"L{p}I{engine.id2token[s.item()]}R{r}" - for p, s, r in zip(life_levels, inventories, rewards) - ], - width=engine.world_width, - ansi_term=ansi_term, + l = environment.state2str(state) + l = [ + v + f"\n{environment.action2str(a.item())}/{r: 3d}" + for (v, a, r) in zip(l, action, rewards) + ] + + to_print += ( + char_conv(fusion_multi_lines(l, width_min=environment.world_width)) + + "\n" ) - time.sleep(0.5) + + print(to_print) + sys.stdout.flush() + time.sleep(0.25) if (life_levels > 0).long().sum() == 0: - break + stop += 1 + if stop == 10: + break - print( - f"timing {(nb_agents*nb_iter)/(time.perf_counter() - start_time)} iteration per s" - ) + print(f"timing {(nb_agents*k)/(time.perf_counter() - start_time)} iteration per s")