agent = torch.zeros(seq.size(), dtype=torch.int64)
agent[:, 0, 0, 0] = 1
agent_actions = torch.randint(5, (nb, T))
+ rewards = torch.zeros(nb, T, dtype=torch.int64)
+
monster = torch.zeros(seq.size(), dtype=torch.int64)
monster[:, 0, -1, -1] = 1
monster_actions = torch.randint(5, (nb, T))
).long()
monster[:, t + 1] = collision * monster[:, t] + (1 - collision) * after_move
+ hit = (
+ (agent[:, t + 1, 1:, :] * monster[:, t + 1, :-1, :]).flatten(1).sum(dim=1)
+ + (agent[:, t + 1, :-1, :] * monster[:, t + 1, 1:, :]).flatten(1).sum(dim=1)
+ + (agent[:, t + 1, :, 1:] * monster[:, t + 1, :, :-1]).flatten(1).sum(dim=1)
+ + (agent[:, t + 1, :, :-1] * monster[:, t + 1, :, 1:]).flatten(1).sum(dim=1)
+ )
+ hit = (hit > 0).long()
+
+ assert hit.min() == 0 and hit.max() <= 1
+
+ rewards[:, t] = -hit + (1 - hit) * agent[:, t + 1, -1, -1]
+
seq += 2 * agent + 3 * monster
- return seq, agent_actions
+ return seq, agent_actions, rewards
######################################################################
def seq2str(seq, actions=None, rewards=None):
    """Render a batch of grid sequences as a box-drawing text table.

    Each batch element becomes one band of text: the frames of the sequence
    are laid out side by side (one column of cells per time step), optionally
    followed by a thin separator and a status line showing, per time step,
    the action letter on the left and the reward on the right.

    Args:
        seq: 4-D integer tensor indexed as ``seq[n, t, i, j]`` (batch, time
            step, grid row, grid column). Each value indexes ``symbols``:
            0 is blank, 2 is the agent and 3 is the monster (the generator
            adds ``2 * agent + 3 * monster``); 1 is drawn as a solid block
            (presumably a wall -- set outside the code visible here).
        actions: optional 2-D integer tensor ``actions[n, t]`` with values
            in ``0..4`` indexing ``"INESW"`` (presumably idle / north /
            east / south / west -- confirm against the move code). When
            ``None`` the separator and status line are omitted, which keeps
            the former ``seq2str(seq)`` call form working.
        rewards: optional 2-D tensor ``rewards[n, t]``. When ``None`` only
            the action letter is shown, which keeps the former
            ``seq2str(seq, actions)`` call form working. Ignored when
            ``actions`` is ``None``.

    Returns:
        The rendered text, every line terminated by ``"\\n"``.

    Note:
        A status cell is not truncated: if the action letter plus the reward
        text is wider than the grid, that cell overflows its column.
    """
    # ASCII fallback: symbols = " #@$" and
    # vert, hori, cross, thin_hori = "|", "-", "+", "-"
    symbols = " █@$"
    vert, hori, cross, thin_hori = "┃", "━", "╋", "─"

    nb_frames, height, width = seq.size(1), seq.size(2), seq.size(-1)

    hline = (cross + hori * width) * nb_frames + cross + "\n"
    thin_line = (vert + thin_hori * width) * nb_frames + vert + "\n"

    def status_cell(action, reward):
        # Action letter flush left, reward (if any) flush right.
        label = "INESW"[action]
        if reward is None:
            return label + " " * (width - len(label))
        value = f"{reward}"
        return label + " " * (width - len(label) - len(value)) + value

    # Collect the pieces and join once instead of growing a string with +=.
    pieces = [hline]

    # One bulk tolist() replaces a .item() call per grid cell.
    for n, frames in enumerate(seq.tolist()):
        for i in range(height):
            cells = ["".join(symbols[v] for v in frame[i]) for frame in frames]
            pieces.append(vert + vert.join(cells) + vert + "\n")

        if actions is not None:
            acts = actions[n].tolist()
            if rewards is None:
                rews = [None] * len(acts)
            else:
                rews = rewards[n].tolist()
            pieces.append(thin_line)
            pieces.append(
                vert
                + vert.join(status_cell(a, r) for a, r in zip(acts, rews))
                + vert
                + "\n"
            )

        pieces.append(hline)

    return "".join(pieces)
######################################################################
if __name__ == "__main__":
    # Demo: generate a small batch and print it with the per-step action
    # and reward shown under each frame. generate_sequence returns three
    # values (sequence, agent actions, rewards), so all three are unpacked.
    seq, actions, rewards = generate_sequence(10, 4, 6, T=20)

    print(seq2str(seq, actions, rewards))