# Scatter nb_coins coins per episode into the first frame (t=0) of a
# (nb, T, height, width) coin map.  A uniform random field picks the
# cells; walls and already-placed coins are masked to 0 so each coin
# lands on a distinct free cell.
# NOTE(review): assumes `wall` broadcasts against (nb, height, width)
# and is non-negative (hence the clamp(max=1) to binarize) — confirm
# against the caller.
rnd = torch.rand(nb, height, width)
coins = torch.zeros(nb, T, height, width, dtype=torch.int64)
rnd = rnd * (1 - wall.clamp(max=1))  # never place a coin on a wall
for _ in range(nb_coins):
    # One-hot of the per-episode argmax of the masked random field,
    # added into the t=0 coin frame.
    coins[:, 0] = coins[:, 0] + (
        rnd.flatten(1).argmax(dim=1)[:, None]
        == torch.arange(rnd.flatten(1).size(1))[None, :]
    ).long().reshape(rnd.size())

    # Zero out the chosen cells so the next coin picks a different one.
    rnd = rnd * (1 - coins[:, 0].clamp(max=1))
-
# Per-time-step tensors, all shaped (nb, T, height, width): the wall
# layout replicated over time, plus one-hot position maps for the two
# entities.  The agent starts in the top-left corner, the troll in the
# bottom-right; each gets one random action in [0, 5) per step
# (presumably stay + the four axis moves — confirm the encoding
# against the move table below).
states = wall[:, None, :, :].expand(-1, T, -1, -1).clone()

agent = torch.zeros(states.size(), dtype=torch.int64)
agent[:, 0, 0, 0] = 1
agent_actions = torch.randint(5, (nb, T))
rewards = torch.zeros(nb, T, dtype=torch.int64)  # filled in later — not written here

troll = torch.zeros(states.size(), dtype=torch.int64)
troll[:, 0, -1, -1] = 1
troll_actions = torch.randint(5, (nb, T))
-
# Roll the dynamics forward for T-1 steps.  At each step the agent
# moves first, then the troll; a move into a wall, off the grid, or
# into the other entity is cancelled and the mover stays put.
#
# new_zeros replaces the legacy uninitialized Tensor.new(sizes); the
# buffer is zero_()-ed at the top of every iteration, so behavior is
# unchanged.
all_moves = agent.new_zeros(nb, 5, height, width)
for t in range(T - 1):
    # Candidate agent positions for the 5 actions: channel 0 = stay,
    # 1 = row+1, 2 = row-1, 3 = col+1, 4 = col-1.  Shifting past the
    # border leaves an all-zero map, which registers as a collision.
    all_moves.zero_()
    all_moves[:, 0] = agent[:, t]
    all_moves[:, 1, 1:, :] = agent[:, t, :-1, :]
    all_moves[:, 2, :-1, :] = agent[:, t, 1:, :]
    all_moves[:, 3, :, 1:] = agent[:, t, :, :-1]
    all_moves[:, 4, :, :-1] = agent[:, t, :, 1:]
    a = F.one_hot(agent_actions[:, t], num_classes=5)[:, :, None, None]
    after_move = (all_moves * a).sum(dim=1)
    # Collision: after masking out walls and the troll's current cell,
    # the candidate map sums to 0 — keep the old position in that case.
    collision = (
        (after_move * (1 - wall) * (1 - troll[:, t]))
        .flatten(1)
        .sum(dim=1)[:, None, None]
        == 0
    ).long()
    agent[:, t + 1] = collision * agent[:, t] + (1 - collision) * after_move

    # Same update for the troll, checked against the agent's NEW
    # position so the two can never end up on the same cell.
    all_moves.zero_()
    all_moves[:, 0] = troll[:, t]
    all_moves[:, 1, 1:, :] = troll[:, t, :-1, :]
    all_moves[:, 2, :-1, :] = troll[:, t, 1:, :]
    all_moves[:, 3, :, 1:] = troll[:, t, :, :-1]
    all_moves[:, 4, :, :-1] = troll[:, t, :, 1:]
    a = F.one_hot(troll_actions[:, t], num_classes=5)[:, :, None, None]
    after_move = (all_moves * a).sum(dim=1)
    collision = (
        (after_move * (1 - wall) * (1 - agent[:, t + 1]))
        .flatten(1)
        .sum(dim=1)[:, None, None]
        == 0
    ).long()
    troll[:, t + 1] = collision * troll[:, t] + (1 - collision) * after_move

    # hit > 0 iff agent and troll are 4-neighbours after this step
    # (each term counts one of the four adjacency directions).
    hit = (
        (agent[:, t + 1, 1:, :] * troll[:, t + 1, :-1, :]).flatten(1).sum(dim=1)
        + (agent[:, t + 1, :-1, :] * troll[:, t + 1, 1:, :]).flatten(1).sum(dim=1)
        + (agent[:, t + 1, :, 1:] * troll[:, t + 1, :, :-1]).flatten(1).sum(dim=1)
        + (agent[:, t + 1, :, :-1] * troll[:, t + 1, :, 1:]).flatten(1).sum(dim=1)
    )
    hit = (hit > 0).long()

    # assert hit.min() == 0 and hit.max() <= 1