agent_actions = torch.randint(5, (nb, T))
rewards = torch.zeros(nb, T, dtype=torch.int64)
- monster = torch.zeros(states.size(), dtype=torch.int64)
- monster[:, 0, -1, -1] = 1
- monster_actions = torch.randint(5, (nb, T))
+ troll = torch.zeros(states.size(), dtype=torch.int64)
+ troll[:, 0, -1, -1] = 1
+ troll_actions = torch.randint(5, (nb, T))
all_moves = agent.new(nb, 5, height, width)
for t in range(T - 1):
a = F.one_hot(agent_actions[:, t], num_classes=5)[:, :, None, None]
after_move = (all_moves * a).sum(dim=1)
collision = (
- (after_move * (1 - wall) * (1 - monster[:, t]))
+ (after_move * (1 - wall) * (1 - troll[:, t]))
.flatten(1)
.sum(dim=1)[:, None, None]
== 0
agent[:, t + 1] = collision * agent[:, t] + (1 - collision) * after_move
all_moves.zero_()
- all_moves[:, 0] = monster[:, t]
- all_moves[:, 1, 1:, :] = monster[:, t, :-1, :]
- all_moves[:, 2, :-1, :] = monster[:, t, 1:, :]
- all_moves[:, 3, :, 1:] = monster[:, t, :, :-1]
- all_moves[:, 4, :, :-1] = monster[:, t, :, 1:]
- a = F.one_hot(monster_actions[:, t], num_classes=5)[:, :, None, None]
+ all_moves[:, 0] = troll[:, t]
+ all_moves[:, 1, 1:, :] = troll[:, t, :-1, :]
+ all_moves[:, 2, :-1, :] = troll[:, t, 1:, :]
+ all_moves[:, 3, :, 1:] = troll[:, t, :, :-1]
+ all_moves[:, 4, :, :-1] = troll[:, t, :, 1:]
+ a = F.one_hot(troll_actions[:, t], num_classes=5)[:, :, None, None]
after_move = (all_moves * a).sum(dim=1)
collision = (
(after_move * (1 - wall) * (1 - agent[:, t + 1]))
.sum(dim=1)[:, None, None]
== 0
).long()
- monster[:, t + 1] = collision * monster[:, t] + (1 - collision) * after_move
+ troll[:, t + 1] = collision * troll[:, t] + (1 - collision) * after_move
hit = (
- (agent[:, t + 1, 1:, :] * monster[:, t + 1, :-1, :]).flatten(1).sum(dim=1)
- + (agent[:, t + 1, :-1, :] * monster[:, t + 1, 1:, :]).flatten(1).sum(dim=1)
- + (agent[:, t + 1, :, 1:] * monster[:, t + 1, :, :-1]).flatten(1).sum(dim=1)
- + (agent[:, t + 1, :, :-1] * monster[:, t + 1, :, 1:]).flatten(1).sum(dim=1)
+ (agent[:, t + 1, 1:, :] * troll[:, t + 1, :-1, :]).flatten(1).sum(dim=1)
+ + (agent[:, t + 1, :-1, :] * troll[:, t + 1, 1:, :]).flatten(1).sum(dim=1)
+ + (agent[:, t + 1, :, 1:] * troll[:, t + 1, :, :-1]).flatten(1).sum(dim=1)
+ + (agent[:, t + 1, :, :-1] * troll[:, t + 1, :, 1:]).flatten(1).sum(dim=1)
)
hit = (hit > 0).long()
rewards[:, t + 1] = -hit + (1 - hit) * got_coin
- states = states + 2 * agent + 3 * monster + 4 * coins
+ states = states + 2 * agent + 3 * troll + 4 * coins * (1 - troll)
return states, agent_actions, rewards
result += hline
if ansi_colors:
- for u, c in [("$", 31), ("@", 32)]:
+ for u, c in [("T", 31), ("@", 32), ("$", 34)]:
result = result.replace(u, f"\u001b[{c}m{u}\u001b[0m")
return result
######################################################################
+
def save_seq_as_anim_script(seq, filename):
    """Write a shell script to *filename* that replays *seq* as a terminal animation.

    The script clears the screen, prints one rendered frame per time step via a
    `cat << EOF` here-doc, and sleeps 0.25 s between frames; run it with `sh`.

    Args:
        seq: token tensor of shape (nb, T * it_len) as produced by episodes2seq.
        filename: path of the shell script to write.
    """
    # One "item" per time step: the flattened height*width state grid plus
    # 3 extra tokens — assumed to match the layout episodes2seq emits.
    # NOTE(review): relies on module globals height, width, T being set by the
    # caller before this is invoked — TODO confirm / consider making them params.
    it_len = height * width + 3

    # (nb, T*it_len) -> (T, nb, it_len): regroup tokens per time step so each
    # animation frame shows all episodes at the same instant.
    seq = (
        seq.reshape(seq.size(0), -1, it_len)
        .permute(1, 0, 2)
        .reshape(T, seq.size(0), -1)
    )

    with open(filename, "w") as f:
        for t in range(T):
            f.write("clear\n")
            f.write("cat << EOF\n")
            # Render frame t with all episodes arranged as rows.
            # NOTE(review): hard-codes a 5 x 10 layout (assumes nb == 50) —
            # generalize to (x, seq.size(1) // x) if other nb values are needed.
            lr, s, a, r = seq2episodes(
                seq[t : t + 1, :].reshape(5, 10 * it_len), height, width
            )
            f.write(episodes2str(lr, s, a, r, unicode=True, ansi_colors=True))
            f.write("EOF\n")
            f.write("sleep 0.25\n")
    # Bug fix: the original printed the literal placeholder "(unknown)" —
    # the f-string had no interpolation; report the actual output path.
    print(f"Saved {filename}")
+
+
if __name__ == "__main__":
    # Quick sanity check: a handful of short episodes rendered to stdout.
    nb, height, width, T, nb_walls = 6, 5, 7, 10, 5
    states, actions, rewards = generate_episodes(nb, height, width, T, nb_walls)
    seq = episodes2seq(states, actions, rewards)
    lr, st, ac, rw = seq2episodes(seq, height, width)
    print(episodes2str(lr, st, ac, rw, unicode=True, ansi_colors=True))

    # Longer run, saved as a replayable shell-script animation.
    # (save_seq_as_anim_script reads the module globals T/height/width, so the
    # reassignment of nb and T here must happen before the call.)
    nb, T = 50, 100
    states, actions, rewards = generate_episodes(
        nb=nb, height=height, width=width, T=T, nb_walls=3
    )
    save_seq_as_anim_script(episodes2seq(states, actions, rewards), "anim.sh")