# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>
import torch
from torch.nn import functional as F
12 ######################################################################
# Token vocabulary: four contiguous code ranges -- states, actions,
# rewards, lookahead rewards.  The nb_* sizes match the symbol tables
# used by token2str (" #@T$", "ISNEW", "-0+", "n.pU").
# NOTE(review): the base size constants and first_states_code were
# missing from the mangled listing and are reconstructed here
# (first_states_code = 0 is the conventional origin -- confirm against
# the original file).
nb_states_codes = 5
nb_actions_codes = 5
nb_rewards_codes = 3
nb_lookahead_rewards_codes = 4  # stands for -1, 0, +1, and UNKNOWN

first_states_code = 0
first_actions_code = first_states_code + nb_states_codes
first_rewards_code = first_actions_code + nb_actions_codes
first_lookahead_rewards_code = first_rewards_code + nb_rewards_codes
nb_codes = first_lookahead_rewards_code + nb_lookahead_rewards_codes

######################################################################


def state2code(r):
    """Map a grid-cell state value (0..4) to its token code."""
    return r + first_states_code


def code2state(r):
    """Inverse of state2code."""
    return r - first_states_code


def action2code(r):
    """Map an action index (0..4) to its token code."""
    return r + first_actions_code


def code2action(r):
    """Inverse of action2code."""
    return r - first_actions_code


def reward2code(r):
    """Map a reward in {-1, 0, +1} to its token code."""
    return r + 1 + first_rewards_code


def code2reward(r):
    """Inverse of reward2code."""
    return r - first_rewards_code - 1


def lookahead_reward2code(r):
    """Map a lookahead reward (-1, 0, +1, or 2 for UNKNOWN) to its token code."""
    return r + 1 + first_lookahead_rewards_code


def code2lookahead_reward(r):
    """Inverse of lookahead_reward2code."""
    return r - first_lookahead_rewards_code - 1
61 ######################################################################
# Generate `nb` random episodes of length T on a height x width grid:
# walls and coins are scattered, then an agent and a troll perform
# random moves; the reward at t+1 is -1 when the agent ends adjacent to
# the troll, otherwise +1 per coin picked up.  Returns
# (states, agent_actions, rewards).
#
# NOTE(review): this chunk is a mangled numbered listing -- each line
# carries a stray original line number and several lines are absent (the
# embedded numbering jumps), so some expressions below have lost their
# assignment targets / closing parentheses.  Comments hedge accordingly.
64 def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3, nb_coins=2):
65 rnd = torch.rand(nb, height, width)
# Place nb_walls single-cell walls, one per iteration at the argmax of
# the random field; the accumulation target (presumably `wall = wall + (`)
# sits on a missing line.
71 for k in range(nb_walls):
73 rnd.flatten(1).argmax(dim=1)[:, None]
74 == torch.arange(rnd.flatten(1).size(1))[None, :]
75 ).long().reshape(rnd.size())
# Zero the occupied cell so the next argmax picks a fresh location.
77 rnd = rnd * (1 - wall.clamp(max=1))
# Same argmax trick to drop nb_coins coins at t=0, masked so they never
# land on a wall.
79 rnd = torch.rand(nb, height, width)
80 coins = torch.zeros(nb, T, height, width, dtype=torch.int64)
81 rnd = rnd * (1 - wall.clamp(max=1))
82 for k in range(nb_coins):
83 coins[:, 0] = coins[:, 0] + (
84 rnd.flatten(1).argmax(dim=1)[:, None]
85 == torch.arange(rnd.flatten(1).size(1))[None, :]
86 ).long().reshape(rnd.size())
88 rnd = rnd * (1 - coins[:, 0].clamp(max=1))
# The per-time-step state starts as the wall map replicated T times;
# agent/troll/coin contributions are added at the very end.
90 states = wall[:, None, :, :].expand(-1, T, -1, -1).clone()
# NOTE(review): the agent's initial position is not set anywhere in the
# visible lines (the tensor is all zeros here) -- presumably assigned on
# a missing line; confirm against the original file.
92 agent = torch.zeros(states.size(), dtype=torch.int64)
94 agent_actions = torch.randint(5, (nb, T))
95 rewards = torch.zeros(nb, T, dtype=torch.int64)
# The troll starts in the bottom-right corner and also moves randomly.
97 troll = torch.zeros(states.size(), dtype=torch.int64)
98 troll[:, 0, -1, -1] = 1
99 troll_actions = torch.randint(5, (nb, T))
# Scratch buffer holding the 5 candidate next-step occupancy maps
# (stay + 4 one-cell shifts).  NOTE(review): .new() leaves it
# uninitialized; a per-step zeroing step is likely on a missing line.
101 all_moves = agent.new(nb, 5, height, width)
102 for t in range(T - 1):
# Candidate moves for the agent: plane 0 = stay, planes 1-4 = the
# occupancy map shifted one cell along each axis/direction.
104 all_moves[:, 0] = agent[:, t]
105 all_moves[:, 1, 1:, :] = agent[:, t, :-1, :]
106 all_moves[:, 2, :-1, :] = agent[:, t, 1:, :]
107 all_moves[:, 3, :, 1:] = agent[:, t, :, :-1]
108 all_moves[:, 4, :, :-1] = agent[:, t, :, 1:]
# Select the plane picked by the sampled action.
109 a = F.one_hot(agent_actions[:, t], num_classes=5)[:, :, None, None]
110 after_move = (all_moves * a).sum(dim=1)
# Collision test: the move is cancelled when the masked destination map
# sums to zero (off the board, into a wall, or onto the troll).
# NOTE(review): the enclosing `collision = (` ... `== 0).long()` lines
# are missing from the listing.
112 (after_move * (1 - wall) * (1 - troll[:, t]))
114 .sum(dim=1)[:, None, None]
117 agent[:, t + 1] = collision * agent[:, t] + (1 - collision) * after_move
# Same move/collision logic for the troll, which additionally may not
# step onto the agent's new position.
120 all_moves[:, 0] = troll[:, t]
121 all_moves[:, 1, 1:, :] = troll[:, t, :-1, :]
122 all_moves[:, 2, :-1, :] = troll[:, t, 1:, :]
123 all_moves[:, 3, :, 1:] = troll[:, t, :, :-1]
124 all_moves[:, 4, :, :-1] = troll[:, t, :, 1:]
125 a = F.one_hot(troll_actions[:, t], num_classes=5)[:, :, None, None]
126 after_move = (all_moves * a).sum(dim=1)
128 (after_move * (1 - wall) * (1 - agent[:, t + 1]))
130 .sum(dim=1)[:, None, None]
133 troll[:, t + 1] = collision * troll[:, t] + (1 - collision) * after_move
# `hit` counts 4-neighbourhood adjacency between agent and troll; the
# `hit = (` opener is on a missing line.
136 (agent[:, t + 1, 1:, :] * troll[:, t + 1, :-1, :]).flatten(1).sum(dim=1)
137 + (agent[:, t + 1, :-1, :] * troll[:, t + 1, 1:, :]).flatten(1).sum(dim=1)
138 + (agent[:, t + 1, :, 1:] * troll[:, t + 1, :, :-1]).flatten(1).sum(dim=1)
139 + (agent[:, t + 1, :, :-1] * troll[:, t + 1, :, 1:]).flatten(1).sum(dim=1)
141 hit = (hit > 0).long()
143 # assert hit.min() == 0 and hit.max() <= 1
# Picking up a coin removes it from the board; it contributes +1 unless
# the agent was hit this step (hit forces reward -1).
145 got_coin = (agent[:, t + 1] * coins[:, t]).flatten(1).sum(dim=1)
146 coins[:, t + 1] = coins[:, t] * (1 - agent[:, t + 1])
148 rewards[:, t + 1] = -hit + (1 - hit) * got_coin
# Final cell encoding: 0 empty, 1 wall, 2 agent, 3 troll, 4 coin
# (matches the " #@T$" symbols used by token2str).
150 states = states + 2 * agent + 3 * troll + 4 * coins * (1 - troll)
152 return states, agent_actions, rewards
155 ######################################################################
def episodes2seq(states, actions, rewards):
    """Flatten episodes into one token sequence per batch element.

    Each time step contributes 1 + height*width + 1 + 1 tokens laid out
    as [lookahead reward, state grid, action, reward] -- the layout that
    seq2episodes inverts.
    """
    # The lookahead reward at time t summarizes the future rewards
    # rewards[:, t:]: the min if any future reward is negative,
    # otherwise the max.
    neg = rewards.new_zeros(rewards.size())
    pos = rewards.new_zeros(rewards.size())
    for t in range(neg.size(1) - 1):
        neg[:, t] = rewards[:, t:].min(dim=-1).values
        pos[:, t] = rewards[:, t:].max(dim=-1).values
    s = (neg < 0).long() * neg + (neg >= 0).long() * pos

    # NOTE(review): the torch.cat / flatten wrapper was missing from the
    # mangled listing; it is reconstructed from seq2episodes, which
    # reshapes the sequence to (N, -1, height*width + 3) with exactly
    # this column order.  Confirm against the original file.
    return torch.cat(
        [
            lookahead_reward2code(s[:, :, None]),
            state2code(states.flatten(2)),
            action2code(actions[:, :, None]),
            reward2code(rewards[:, :, None]),
        ],
        dim=2,
    ).flatten(1)
def seq2episodes(seq, height, width):
    """Inverse of episodes2seq.

    Splits a flat token sequence back into per-time-step components.
    Each step occupies height*width + 3 tokens laid out as
    [lookahead reward, state grid, action, reward].

    Returns (lookahead_rewards, states, actions, rewards), with states
    reshaped to (N, T, height, width).
    """
    # De-mangled from the numbered listing; the repeated
    # `height * width + 3` item length is hoisted into a local.
    it_len = height * width + 3
    seq = seq.reshape(seq.size(0), -1, it_len)
    lookahead_rewards = code2lookahead_reward(seq[:, :, 0])
    states = code2state(seq[:, :, 1 : it_len - 2])
    states = states.reshape(states.size(0), states.size(1), height, width)
    actions = code2action(seq[:, :, it_len - 2])
    rewards = code2reward(seq[:, :, it_len - 1])
    return lookahead_rewards, states, actions, rewards
def token2str(t):
    """Render a single token code as its one-character symbol.

    States map to " #@T$" (empty, wall, agent, troll, coin -- matching
    the encoding in generate_episodes), actions to "ISNEW", rewards to
    "-0+", lookahead rewards to "n.pU".

    NOTE(review): the `def` header, the `elif (`/`):` wrapper lines and
    the fall-through branch were missing from the mangled listing and
    are reconstructed here; "?" matches the out-of-range style used by
    the status bar in episodes2str -- confirm against the original file.
    """
    if t >= first_states_code and t < first_states_code + nb_states_codes:
        return " #@T$"[t - first_states_code]
    elif t >= first_actions_code and t < first_actions_code + nb_actions_codes:
        return "ISNEW"[t - first_actions_code]
    elif t >= first_rewards_code and t < first_rewards_code + nb_rewards_codes:
        return "-0+"[t - first_rewards_code]
    elif (
        t >= first_lookahead_rewards_code
        and t < first_lookahead_rewards_code + nb_lookahead_rewards_codes
    ):
        return "n.pU"[t - first_lookahead_rewards_code]
    else:
        return "?"


def seq2str(seq):
    """Render a batch of token sequences, one string per row."""
    return ["".join([token2str(x.item()) for x in row]) for row in seq]
206 ######################################################################
# Renders a batch of episodes side by side as a text grid, optionally
# with unicode box-drawing characters and ANSI colors.
# NOTE(review): the `def episodes2str(` header line and a large part of
# the body are missing from this mangled listing; only the parameter
# line and fragments remain, so comments below are hedged.
210 lookahead_rewards, states, actions, rewards, unicode=False, ansi_colors=False
# Frame-drawing glyphs: unicode variant...
214 # vert, hori, cross, thin_hori = "║", "═", "╬", "─"
215 vert, hori, cross, thin_vert, thin_hori = "┃", "━", "╋", "│", "─"
# ...and the plain-ASCII fallback (presumably in an else branch).
218 vert, hori, cross, thin_vert, thin_hori = "|", "-", "+", "|", "-"
# Horizontal separator spanning all T frames of one episode.
220 hline = (cross + hori * states.size(-1)) * states.size(1) + cross + "\n"
224 for n in range(states.size(0)):
# Helper mapping a cell value to its display character ("?" when out of
# range of the symbol table).
228 return "?" if v < 0 or v >= len(symbols) else symbols[v]
# One text row per grid row, all time steps rendered side by side.
230 for i in range(states.size(2)):
234 ["".join([state_symbol(v) for v in row]) for row in states[n, :, i]]
240 # result += (vert + thin_hori * states.size(-1)) * states.size(1) + vert + "\n"
# Status bar under each frame: action letter ("ISNEW"), reward sign,
# and (when given) the lookahead reward, padded to the frame width.
242 def status_bar(a, r, lr=None):
243 a, r = a.item(), r.item()
244 sb_a = "ISNEW"[a] if a >= 0 and a < 5 else "?"
245 sb_r = "- +"[r + 1] if r in {-1, 0, 1} else "?"
250 sb_lr = "n pU"[lr + 1] if lr in {-1, 0, 1, 2} else "?"
255 + " " * (states.size(-1) - 1 - len(sb_a + sb_r + sb_lr))
264 for a, r, lr in zip(actions[n], rewards[n], lookahead_rewards[n])
# Colorize the troll/agent/coin glyphs with ANSI SGR escape codes
# (31 = red, 32 = green, 34 = blue).
274 for u, c in [("T", 31), ("@", 32), ("$", 34)]:
275 result = result.replace(u, f"\u001b[{c}m{u}\u001b[0m")
280 ######################################################################
# Write `filename` as a shell script that plays the episodes as a
# terminal animation: for each time step it emits `clear`, a here-doc
# containing the rendered frame, and a short sleep.
# NOTE(review): mangled listing with missing lines.  The function reads
# module-level globals (height, width, T) rather than deriving them from
# `seq`, and the hard-coded reshape(5, 10 * it_len) below looks specific
# to one particular demo batch -- confirm.
283 def save_seq_as_anim_script(seq, filename):
284 it_len = height * width + 3
# Reorder from (batch, time * it_len) to time-major so each animation
# frame shows the whole batch at one time step.
287 seq.reshape(seq.size(0), -1, it_len)
289 .reshape(T, seq.size(0), -1)
292 with open(filename, "w") as f:
295 f.write("cat << EOF\n")
296 # for i in range(seq.size(2)):
297 # lr, s, a, r = seq2episodes(seq[t : t + 1, :, i], height, width)
298 lr, s, a, r = seq2episodes(
299 seq[t : t + 1, :].reshape(5, 10 * it_len), height, width
301 f.write(episodes2str(lr, s, a, r, unicode=True, ansi_colors=True))
303 f.write("sleep 0.25\n")
# NOTE(review): the message literal reads "(unknown)" and the f-string
# has no placeholder left; presumably it was f"Saved {filename}" before
# the listing was mangled -- confirm against the original file.
304 print(f"Saved (unknown)")
# Demo / smoke test: sample a batch of episodes, round-trip them through
# the token encoding, print the rendering, then write a terminal
# animation script.  NOTE(review): lines are missing from this mangled
# listing (the embedded numbering jumps).
307 if __name__ == "__main__":
308 nb, height, width, T, nb_walls = 6, 5, 7, 10, 5
309 states, actions, rewards = generate_episodes(nb, height, width, T, nb_walls)
310 seq = episodes2seq(states, actions, rewards)
# Decode back and print -- a visual check that seq2episodes inverts
# episodes2seq.
311 lr, s, a, r = seq2episodes(seq, height, width)
312 print(episodes2str(lr, s, a, r, unicode=True, ansi_colors=True))
315 # for s in seq2str(seq):
# A second batch feeds the animation-script writer.
319 states, actions, rewards = generate_episodes(
320 nb=nb, height=height, width=width, T=T, nb_walls=3
322 seq = episodes2seq(states, actions, rewards)
323 save_seq_as_anim_script(seq, "anim.sh")