agent_actions = torch.randint(5, (nb, T))
rewards = torch.zeros(nb, T, dtype=torch.int64)
- monster = torch.zeros(states.size(), dtype=torch.int64)
- monster[:, 0, -1, -1] = 1
- monster_actions = torch.randint(5, (nb, T))
+ troll = torch.zeros(states.size(), dtype=torch.int64)
+ troll[:, 0, -1, -1] = 1
+ troll_actions = torch.randint(5, (nb, T))
all_moves = agent.new(nb, 5, height, width)
for t in range(T - 1):
a = F.one_hot(agent_actions[:, t], num_classes=5)[:, :, None, None]
after_move = (all_moves * a).sum(dim=1)
collision = (
- (after_move * (1 - wall) * (1 - monster[:, t]))
+ (after_move * (1 - wall) * (1 - troll[:, t]))
.flatten(1)
.sum(dim=1)[:, None, None]
== 0
agent[:, t + 1] = collision * agent[:, t] + (1 - collision) * after_move
all_moves.zero_()
- all_moves[:, 0] = monster[:, t]
- all_moves[:, 1, 1:, :] = monster[:, t, :-1, :]
- all_moves[:, 2, :-1, :] = monster[:, t, 1:, :]
- all_moves[:, 3, :, 1:] = monster[:, t, :, :-1]
- all_moves[:, 4, :, :-1] = monster[:, t, :, 1:]
- a = F.one_hot(monster_actions[:, t], num_classes=5)[:, :, None, None]
+ all_moves[:, 0] = troll[:, t]
+ all_moves[:, 1, 1:, :] = troll[:, t, :-1, :]
+ all_moves[:, 2, :-1, :] = troll[:, t, 1:, :]
+ all_moves[:, 3, :, 1:] = troll[:, t, :, :-1]
+ all_moves[:, 4, :, :-1] = troll[:, t, :, 1:]
+ a = F.one_hot(troll_actions[:, t], num_classes=5)[:, :, None, None]
after_move = (all_moves * a).sum(dim=1)
collision = (
(after_move * (1 - wall) * (1 - agent[:, t + 1]))
.sum(dim=1)[:, None, None]
== 0
).long()
- monster[:, t + 1] = collision * monster[:, t] + (1 - collision) * after_move
+ troll[:, t + 1] = collision * troll[:, t] + (1 - collision) * after_move
hit = (
- (agent[:, t + 1, 1:, :] * monster[:, t + 1, :-1, :]).flatten(1).sum(dim=1)
- + (agent[:, t + 1, :-1, :] * monster[:, t + 1, 1:, :]).flatten(1).sum(dim=1)
- + (agent[:, t + 1, :, 1:] * monster[:, t + 1, :, :-1]).flatten(1).sum(dim=1)
- + (agent[:, t + 1, :, :-1] * monster[:, t + 1, :, 1:]).flatten(1).sum(dim=1)
+ (agent[:, t + 1, 1:, :] * troll[:, t + 1, :-1, :]).flatten(1).sum(dim=1)
+ + (agent[:, t + 1, :-1, :] * troll[:, t + 1, 1:, :]).flatten(1).sum(dim=1)
+ + (agent[:, t + 1, :, 1:] * troll[:, t + 1, :, :-1]).flatten(1).sum(dim=1)
+ + (agent[:, t + 1, :, :-1] * troll[:, t + 1, :, 1:]).flatten(1).sum(dim=1)
)
hit = (hit > 0).long()
rewards[:, t + 1] = -hit + (1 - hit) * got_coin
- states = states + 2 * agent + 3 * monster + 4 * coins
+ states = states + 2 * agent + 3 * troll + 4 * coins * (1 - troll)
return states, agent_actions, rewards
result += hline
if ansi_colors:
- for u, c in [("$", 31), ("@", 32)]:
+ for u, c in [("T", 31), ("@", 32), ("$", 34)]:
result = result.replace(u, f"\u001b[{c}m{u}\u001b[0m")
return result