3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
10 from torch.nn import functional as F
12 ######################################################################
# Token vocabulary layout: states, actions, rewards and lookahead rewards
# each occupy a contiguous range of codes starting at their first_*_code
# offset; nb_codes is the total vocabulary size.
# NOTE(review): nb_states_codes, nb_actions_codes, nb_rewards_codes and
# first_states_code are used here but their definitions are not present in
# this listing — confirm they exist and match the symbol tables used when
# rendering (4 state symbols, 5 action letters, 3 reward symbols).
17 nb_lookahead_rewards_codes = 3
20 first_actions_code = first_states_code + nb_states_codes
21 first_rewards_code = first_actions_code + nb_actions_codes
22 first_lookahead_rewards_code = first_rewards_code + nb_rewards_codes
23 nb_codes = first_lookahead_rewards_code + nb_lookahead_rewards_codes
25 ######################################################################
def state2code(r):
    """Map a cell state value (int or integer tensor) to its token code."""
    return r + first_states_code


def code2state(r):
    """Inverse of state2code: map a token code back to a cell state value."""
    return r - first_states_code


def action2code(r):
    """Map an action index to its token code."""
    return r + first_actions_code


def code2action(r):
    """Inverse of action2code: map a token code back to an action index."""
    return r - first_actions_code


def reward2code(r):
    """Map a reward (-1, 0 or +1 in this file) to its token code.

    The reward is shifted by one first so that -1 lands on first_rewards_code.
    """
    return r + 1 + first_rewards_code


def code2reward(r):
    """Inverse of reward2code: map a token code back to a reward."""
    return r - first_rewards_code - 1
def lookahead_reward2code(r):
    """Encode a lookahead reward as a token code.

    The value is shifted up by one before the block offset is applied, so a
    lookahead reward of -1 maps to first_lookahead_rewards_code.
    """
    shifted = r + 1
    return shifted + first_lookahead_rewards_code
def code2lookahead_reward(r):
    """Decode a token code into a lookahead reward (inverse of lookahead_reward2code)."""
    relative = r - first_lookahead_rewards_code
    return relative - 1
60 ######################################################################
# Generate `nb` random grid-world episodes of length `T` on a height x width
# grid with up to `nb_walls` wall cells, an agent and a monster that both
# take random actions.
# Returns (states, agent_actions, rewards):
#   states        -- (nb, T, height, width); per cell wall + 2*agent + 3*monster
#   agent_actions -- (nb, T) random actions in 0..4
#   rewards       -- (nb, T) int64; rewards[:, 0] is always 0
# NOTE(review): this is a numbered listing with source lines missing (e.g. the
# wall/agent initialisation and the openers of the `wall`, `collision` and
# `hit` expressions); comments below describe only the visible lines.
63 def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3):
64 rnd = torch.rand(nb, height, width)
71 for k in range(nb_walls):
# One-hot (nb, height, width) mask of each grid's argmax cell in `rnd`;
# presumably accumulated into `wall` by an opening line missing here.
73 rnd.flatten(1).argmax(dim=1)[:, None]
74 == torch.arange(rnd.flatten(1).size(1))[None, :]
75 ).long().reshape(rnd.size())
# Zero the score of every wall cell so the next argmax picks a new cell.
77 rnd = rnd * (1 - wall.clamp(max=1))
# Every time step starts from the wall layout; actors are added at the end.
79 states = wall[:, None, :, :].expand(-1, T, -1, -1).clone()
# Position maps for the agent, (nb, T, height, width).
81 agent = torch.zeros(states.size(), dtype=torch.int64)
83 agent_actions = torch.randint(5, (nb, T))
84 rewards = torch.zeros(nb, T, dtype=torch.int64)
# The monster starts in the bottom-right cell.
86 monster = torch.zeros(states.size(), dtype=torch.int64)
87 monster[:, 0, -1, -1] = 1
88 monster_actions = torch.randint(5, (nb, T))
# Scratch buffer holding the 5 candidate positions of one actor.
# NOTE(review): Tensor.new(...) returns uninitialized memory, the shifted
# writes below leave one border row/column of channels 1-4 untouched, and the
# buffer is reused for the monster. It must be zeroed before each fill; that
# line is not visible in this listing — confirm it exists.
90 all_moves = agent.new(nb, 5, height, width)
91 for t in range(T - 1):
# Candidate agent positions: 0 = stay, 1 = row + 1, 2 = row - 1,
# 3 = column + 1, 4 = column - 1 (the letters "ISNEW" used when rendering).
# A move off the grid yields an all-zero map.
93 all_moves[:, 0] = agent[:, t]
94 all_moves[:, 1, 1:, :] = agent[:, t, :-1, :]
95 all_moves[:, 2, :-1, :] = agent[:, t, 1:, :]
96 all_moves[:, 3, :, 1:] = agent[:, t, :, :-1]
97 all_moves[:, 4, :, :-1] = agent[:, t, :, 1:]
# Keep only the candidate selected by each episode's action.
98 a = F.one_hot(agent_actions[:, t], num_classes=5)[:, :, None, None]
99 after_move = (all_moves * a).sum(dim=1)
# Per-episode amount of the moved agent landing on cells that are neither
# wall nor monster; the lines turning this into the 0/1 `collision` flag are
# missing from this listing.
101 (after_move * (1 - wall) * (1 - monster[:, t]))
103 .sum(dim=1)[:, None, None]
# On collision the agent keeps its previous position, else takes the new one.
106 agent[:, t + 1] = collision * agent[:, t] + (1 - collision) * after_move
# Same move/collision logic for the monster, which may not enter the agent's
# new cell.
109 all_moves[:, 0] = monster[:, t]
110 all_moves[:, 1, 1:, :] = monster[:, t, :-1, :]
111 all_moves[:, 2, :-1, :] = monster[:, t, 1:, :]
112 all_moves[:, 3, :, 1:] = monster[:, t, :, :-1]
113 all_moves[:, 4, :, :-1] = monster[:, t, :, 1:]
114 a = F.one_hot(monster_actions[:, t], num_classes=5)[:, :, None, None]
115 after_move = (all_moves * a).sum(dim=1)
117 (after_move * (1 - wall) * (1 - agent[:, t + 1]))
119 .sum(dim=1)[:, None, None]
122 monster[:, t + 1] = collision * monster[:, t] + (1 - collision) * after_move
# Adjacency test: non-zero when agent and monster occupy 4-neighbouring cells
# after this step (the `hit = (` opener is missing from this listing).
125 (agent[:, t + 1, 1:, :] * monster[:, t + 1, :-1, :]).flatten(1).sum(dim=1)
126 + (agent[:, t + 1, :-1, :] * monster[:, t + 1, 1:, :]).flatten(1).sum(dim=1)
127 + (agent[:, t + 1, :, 1:] * monster[:, t + 1, :, :-1]).flatten(1).sum(dim=1)
128 + (agent[:, t + 1, :, :-1] * monster[:, t + 1, :, 1:]).flatten(1).sum(dim=1)
130 hit = (hit > 0).long()
132 # assert hit.min() == 0 and hit.max() <= 1
# Reward: -1 when next to the monster; otherwise the agent's value in the
# bottom-right cell (1 there if the agent map is one-hot, else 0).
134 rewards[:, t + 1] = -hit + (1 - hit) * agent[:, t + 1, -1, -1]
# Final cell values: 0 empty, 1 wall, 2 agent, 3 monster (assuming actors
# never overlap walls or each other, which the collision test should ensure).
136 states += 2 * agent + 3 * monster
138 return states, agent_actions, rewards
141 ######################################################################
# Encode episodes as token sequences. Per time step the tokens are, in order:
# lookahead reward, the height*width state cells, the action, and the reward.
# This is the layout seq2episodes decodes.
144 def episodes2seq(states, actions, rewards):
# neg[:, t] / pos[:, t]: minimum / maximum reward over steps t..end.
145 neg = rewards.new_zeros(rewards.size())
146 pos = rewards.new_zeros(rewards.size())
# NOTE(review): the loop stops at size(1) - 1, so the last step's lookahead
# stays 0 even when rewards[:, -1] is non-zero — confirm this is intended.
147 for t in range(neg.size(1) - 1):
148 neg[:, t] = rewards[:, t:].min(dim=-1).values
149 pos[:, t] = rewards[:, t:].max(dim=-1).values
# Lookahead reward: the future minimum when it is negative (a pending penalty
# dominates), otherwise the future maximum.
150 s = (neg < 0).long() * neg + (neg >= 0).long() * pos
# NOTE(review): the opening of the concatenation and its closing/flattening
# lines are missing from this listing; only the per-step pieces are visible.
154 lookahead_reward2code(s[:, :, None]),
155 state2code(states.flatten(2)),
156 action2code(actions[:, :, None]),
157 reward2code(rewards[:, :, None]),
def seq2episodes(seq, height, width):
    """Decode token sequences back into episode tensors.

    Each time step occupies ``height * width + 3`` consecutive tokens:
    [lookahead reward, state cells (row-major), action, reward].

    Returns (lookahead_rewards, states, actions, rewards), with states
    reshaped to (batch, time, height, width).
    """
    cells = height * width
    steps = seq.reshape(seq.size(0), -1, cells + 3)

    lookahead_rewards = code2lookahead_reward(steps[:, :, 0])

    flat_states = code2state(steps[:, :, 1 : cells + 1])
    states = flat_states.reshape(
        flat_states.size(0), flat_states.size(1), height, width
    )

    actions = code2action(steps[:, :, cells + 1])
    rewards = code2reward(steps[:, :, cells + 2])

    return lookahead_rewards, states, actions, rewards
# Map a single token code `t` to its display character:
#   state codes     -> " #@$"  (cell values 0 empty, 1 wall, 2 agent, 3 monster)
#   action codes    -> "ISNEW"
#   reward codes    -> "-0+"   (rewards -1, 0, +1)
#   lookahead codes -> "n.p"
# NOTE(review): the `def` line, the `elif (` opener of the lookahead branch
# and any final fallback are missing from this listing. Without a fallback,
# codes outside every range return None, and seq2str's join would fail.
# NOTE(review): only 4 state symbols exist; indexing raises IndexError if
# nb_states_codes > 4 — confirm.
175 if t >= first_states_code and t < first_states_code + nb_states_codes:
176 return " #@$"[t - first_states_code]
177 elif t >= first_actions_code and t < first_actions_code + nb_actions_codes:
178 return "ISNEW"[t - first_actions_code]
179 elif t >= first_rewards_code and t < first_rewards_code + nb_rewards_codes:
180 return "-0+"[t - first_rewards_code]
182 t >= first_lookahead_rewards_code
183 and t < first_lookahead_rewards_code + nb_lookahead_rewards_codes
185 return "n.p"[t - first_lookahead_rewards_code]
def seq2str(seq):
    """Render each row of a 2D token tensor as a string, one character per token.

    Characters come from token2str; returns a list with one string per row.
    """
    return ["".join([token2str(x.item()) for x in row]) for row in seq]
192 ######################################################################
# Render episodes as text. For each episode, the grids of all time steps are
# drawn side by side, one text line per grid row, with a per-step status bar
# (action, reward, lookahead reward).
# NOTE(review): numbered listing with many source lines missing (the
# `def episodes2str(` line, the unicode/ascii branch tests, the `symbols`
# table used by state_symbol, and most of the string assembly). Comments only
# cover the visible lines.
196 lookahead_rewards, states, actions, rewards, unicode=False, ansi_colors=False
# Commented-out alternative (double-line) box-drawing set.
200 # vert, hori, cross, thin_hori = "║", "═", "╬", "─"
# Heavy box-drawing characters, presumably selected when unicode=True.
201 vert, hori, cross, thin_vert, thin_hori = "┃", "━", "╋", "│", "─"
# Plain ASCII fallback, presumably selected when unicode=False.
204 vert, hori, cross, thin_vert, thin_hori = "|", "-", "+", "|", "-"
# Horizontal separator: one `width`-long segment per time step, joined by
# `cross` characters.
206 hline = (cross + hori * states.size(-1)) * states.size(1) + cross + "\n"
210 for n in range(states.size(0)):
# state_symbol: cell values outside the symbol table render as "?".
214 return "?" if v < 0 or v >= len(symbols) else symbols[v]
# For grid row i, collect that row from every time step of episode n.
216 for i in range(states.size(2)):
220 ["".join([state_symbol(v) for v in row]) for row in states[n, :, i]]
226 # result += (vert + thin_hori * states.size(-1)) * states.size(1) + vert + "\n"
# status_bar: text labels for one time step. Zero rewards render as a blank
# (space); out-of-range values render as "?".
228 def status_bar(a, r, lr=None):
229 a, r = a.item(), r.item()
230 sb_a = "ISNEW"[a] if a >= 0 and a < 5 else "?"
231 sb_r = "- +"[r + 1] if r in {-1, 0, 1} else "?"
# NOTE(review): how lr (including lr=None) is converted before this line is
# not visible in this listing.
236 sb_lr = "n p"[lr + 1] if lr in {-1, 0, 1} else "?"
# Space padding of (grid width - 1 - label length) characters.
241 + " " * (states.size(-1) - 1 - len(sb_a + sb_r + sb_lr))
250 for a, r, lr in zip(actions[n], rewards[n], lookahead_rewards[n])
# Optional ANSI colouring: "$" in red (SGR 31) and "@" in green (SGR 32),
# presumably applied when ansi_colors=True.
260 for u, c in [("$", 31), ("@", 32)]:
261 result = result.replace(u, f"\u001b[{c}m{u}\u001b[0m")
266 ######################################################################
# Demo: generate a few small episodes, round-trip them through the token
# encoding, and print the rendered grids followed by the raw token strings.
268 if __name__ == "__main__":
269 nb, height, width, T, nb_walls = 5, 5, 7, 4, 5
270 states, actions, rewards = generate_episodes(nb, height, width, T, nb_walls)
# Encode to tokens and decode back, so the rendering below shows decoded data
# (this exercises episodes2seq / seq2episodes as inverses).
271 seq = episodes2seq(states, actions, rewards)
272 lr, s, a, r = seq2episodes(seq, height, width)
273 print(episodes2str(lr, s, a, r, unicode=True, ansi_colors=True))
# One string per episode; the loop body lies beyond this listing.
275 for s in seq2str(seq):