-
- rnd = torch.rand(nb, height, width)
- rnd[:, 0, 0] = 0  # Do not put a coin at the agent's starting position
- coins = torch.zeros(nb, T, height, width, dtype=torch.int64)
- rnd = rnd * (1 - wall.clamp(max=1))
- for k in range(nb_coins):
- coins[:, 0] = coins[:, 0] + (
- rnd.flatten(1).argmax(dim=1)[:, None]
- == torch.arange(rnd.flatten(1).size(1))[None, :]
- ).long().reshape(rnd.size())
-
- rnd = rnd * (1 - coins[:, 0].clamp(max=1))
-
- states = wall[:, None, :, :].expand(-1, T, -1, -1).clone()
-
- agent = torch.zeros(states.size(), dtype=torch.int64)
- agent[:, 0, 0, 0] = 1
- agent_actions = torch.randint(5, (nb, T))
- rewards = torch.zeros(nb, T, dtype=torch.int64)
-
- troll = torch.zeros(states.size(), dtype=torch.int64)
- troll[:, 0, -1, -1] = 1
- troll_actions = torch.randint(5, (nb, T))
-
- all_moves = agent.new(nb, 5, height, width)
- for t in range(T - 1):
- all_moves.zero_()
- all_moves[:, 0] = agent[:, t]
- all_moves[:, 1, 1:, :] = agent[:, t, :-1, :]
- all_moves[:, 2, :-1, :] = agent[:, t, 1:, :]
- all_moves[:, 3, :, 1:] = agent[:, t, :, :-1]
- all_moves[:, 4, :, :-1] = agent[:, t, :, 1:]
- a = F.one_hot(agent_actions[:, t], num_classes=5)[:, :, None, None]
- after_move = (all_moves * a).sum(dim=1)
- collision = (
- (after_move * (1 - wall) * (1 - troll[:, t]))
- .flatten(1)
- .sum(dim=1)[:, None, None]
- == 0
- ).long()
- agent[:, t + 1] = collision * agent[:, t] + (1 - collision) * after_move
-
- all_moves.zero_()
- all_moves[:, 0] = troll[:, t]
- all_moves[:, 1, 1:, :] = troll[:, t, :-1, :]
- all_moves[:, 2, :-1, :] = troll[:, t, 1:, :]
- all_moves[:, 3, :, 1:] = troll[:, t, :, :-1]
- all_moves[:, 4, :, :-1] = troll[:, t, :, 1:]
- a = F.one_hot(troll_actions[:, t], num_classes=5)[:, :, None, None]
- after_move = (all_moves * a).sum(dim=1)
- collision = (
- (after_move * (1 - wall) * (1 - agent[:, t + 1]))
- .flatten(1)
- .sum(dim=1)[:, None, None]
- == 0
- ).long()
- troll[:, t + 1] = collision * troll[:, t] + (1 - collision) * after_move
-
- hit = (
- (agent[:, t + 1, 1:, :] * troll[:, t + 1, :-1, :]).flatten(1).sum(dim=1)
- + (agent[:, t + 1, :-1, :] * troll[:, t + 1, 1:, :]).flatten(1).sum(dim=1)
- + (agent[:, t + 1, :, 1:] * troll[:, t + 1, :, :-1]).flatten(1).sum(dim=1)
- + (agent[:, t + 1, :, :-1] * troll[:, t + 1, :, 1:]).flatten(1).sum(dim=1)
- )
- hit = (hit > 0).long()
-
- # assert hit.min() == 0 and hit.max() <= 1
-
- got_coin = (agent[:, t + 1] * coins[:, t]).flatten(1).sum(dim=1)
- coins[:, t + 1] = coins[:, t] * (1 - agent[:, t + 1])
-
- rewards[:, t + 1] = -hit + (1 - hit) * got_coin
-
- states = states + 2 * agent + 3 * troll + 4 * coins * (1 - troll)
-
- return states, agent_actions, rewards
-
-
-######################################################################
-
-
-def episodes2seq(states, actions, rewards):
- neg = rewards.new_zeros(rewards.size())
- pos = rewards.new_zeros(rewards.size())
- for t in range(neg.size(1) - 1):
- neg[:, t] = rewards[:, t:].min(dim=-1).values
- pos[:, t] = rewards[:, t:].max(dim=-1).values
- s = (neg < 0).long() * neg + (neg >= 0).long() * pos
-
- return torch.cat(
- [
- lookahead_reward2code(s[:, :, None]),
- state2code(states.flatten(2)),
- action2code(actions[:, :, None]),
- reward2code(rewards[:, :, None]),
- ],
- dim=2,
- ).flatten(1)
-
-
-def seq2episodes(seq, height, width):
- seq = seq.reshape(seq.size(0), -1, height * width + 3)
- lookahead_rewards = code2lookahead_reward(seq[:, :, 0])
- states = code2state(seq[:, :, 1 : height * width + 1])
- states = states.reshape(states.size(0), states.size(1), height, width)
- actions = code2action(seq[:, :, height * width + 1])
- rewards = code2reward(seq[:, :, height * width + 2])
- return lookahead_rewards, states, actions, rewards
-
-
-def seq2str(seq):
- def token2str(t):
- if t >= first_states_code and t < first_states_code + nb_states_codes:
- return " #@T$"[t - first_states_code]
- elif t >= first_actions_code and t < first_actions_code + nb_actions_codes:
- return "ISNEW"[t - first_actions_code]
- elif t >= first_rewards_code and t < first_rewards_code + nb_rewards_codes:
- return "-0+"[t - first_rewards_code]
- elif (
- t >= first_lookahead_rewards_code
- and t < first_lookahead_rewards_code + nb_lookahead_rewards_codes
- ):
- return "n.pU"[t - first_lookahead_rewards_code]
+ for k in range(self.nb_coins):
+ coins[:, 0] = coins[:, 0] + (
+ rnd.flatten(1).argmax(dim=1)[:, None]
+ == torch.arange(rnd.flatten(1).size(1))[None, :]
+ ).long().reshape(rnd.size())
+
+ rnd = rnd * (1 - coins[:, 0].clamp(max=1))
+
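+ # One grid per time step, pre-filled with the wall layout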
+ states = wall[:, None, :, :].expand(-1, self.T, -1, -1).clone()
+
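+ # The agent starts in the top-left corner and follows a uniformly
+ # random action sequence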
+ agent = torch.zeros(states.size(), dtype=torch.int64)
+ agent[:, 0, 0, 0] = 1
+ agent_actions = torch.randint(5, (nb, self.T))
+ rewards = torch.zeros(nb, self.T, dtype=torch.int64)
+
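+ # The troll starts in the bottom-right corner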
+ troll = torch.zeros(states.size(), dtype=torch.int64)
+ troll[:, 0, -1, -1] = 1
+ troll_actions = torch.randint(5, (nb, self.T))
+
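+ # all_moves[:, k] is the agent's occupancy map after move k, the
+ # five moves being idle, south, north, east, west (the "ISNEW"
+ # action codes)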
+ all_moves = agent.new_empty(nb, 5, self.height, self.width)
+ for t in range(self.T - 1):
+ all_moves.zero_()
+ all_moves[:, 0] = agent[:, t]
+ all_moves[:, 1, 1:, :] = agent[:, t, :-1, :]
+ all_moves[:, 2, :-1, :] = agent[:, t, 1:, :]
+ all_moves[:, 3, :, 1:] = agent[:, t, :, :-1]
+ all_moves[:, 4, :, :-1] = agent[:, t, :, 1:]
+ a = F.one_hot(agent_actions[:, t], num_classes=5)[:, :, None, None]
+ after_move = (all_moves * a).sum(dim=1)
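+ # Cancel the move if it lands on a wall, on the troll, or off the
+ # grid: in all three cases the masked occupancy map sums to zero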
+ collision = (
+ (after_move * (1 - wall) * (1 - troll[:, t]))
+ .flatten(1)
+ .sum(dim=1)[:, None, None]
+ == 0
+ ).long()
+ agent[:, t + 1] = collision * agent[:, t] + (1 - collision) * after_move
+
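+ # Move the troll with the same dynamics, treating the agent's
+ # updated position as an obstacle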
+ all_moves.zero_()
+ all_moves[:, 0] = troll[:, t]
+ all_moves[:, 1, 1:, :] = troll[:, t, :-1, :]
+ all_moves[:, 2, :-1, :] = troll[:, t, 1:, :]
+ all_moves[:, 3, :, 1:] = troll[:, t, :, :-1]
+ all_moves[:, 4, :, :-1] = troll[:, t, :, 1:]
+ a = F.one_hot(troll_actions[:, t], num_classes=5)[:, :, None, None]
+ after_move = (all_moves * a).sum(dim=1)
+ collision = (
+ (after_move * (1 - wall) * (1 - agent[:, t + 1]))
+ .flatten(1)
+ .sum(dim=1)[:, None, None]
+ == 0
+ ).long()
+ troll[:, t + 1] = collision * troll[:, t] + (1 - collision) * after_move
+
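+ # The agent is hit when the troll ends up in a cell 4-adjacent to
+ # its own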
+ hit = (
+ (agent[:, t + 1, 1:, :] * troll[:, t + 1, :-1, :]).flatten(1).sum(dim=1)
+ + (agent[:, t + 1, :-1, :] * troll[:, t + 1, 1:, :])
+ .flatten(1)
+ .sum(dim=1)
+ + (agent[:, t + 1, :, 1:] * troll[:, t + 1, :, :-1])
+ .flatten(1)
+ .sum(dim=1)
+ + (agent[:, t + 1, :, :-1] * troll[:, t + 1, :, 1:])
+ .flatten(1)
+ .sum(dim=1)
+ )
+ hit = (hit > 0).long()
+
+ # assert hit.min() == 0 and hit.max() <= 1
+
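+ # A coin is collected, and removed from the map, when the agent
+ # steps on it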
+ got_coin = (agent[:, t + 1] * coins[:, t]).flatten(1).sum(dim=1)
+ coins[:, t + 1] = coins[:, t] * (1 - agent[:, t + 1])
+
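+ # Reward is -1 on a hit, +1 when a coin is collected, 0 otherwise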
+ rewards[:, t + 1] = -hit + (1 - hit) * got_coin
+
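+ # Cell codes: 0 empty, 1 wall, 2 agent, 3 troll, 4 coin (a coin
+ # the troll stands on is displayed as the troll)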
+ states = states + 2 * agent + 3 * troll + 4 * coins * (1 - troll)
+
+ return states, agent_actions, rewards
+
+ ######################################################################
+
+ def episodes2seq(self, states, actions, rewards):
+ neg = rewards.new_zeros(rewards.size())
+ pos = rewards.new_zeros(rewards.size())
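+ # The lookahead reward at t is the most negative reward still to
+ # come if there is one, otherwise the most positive; the last time
+ # step is left at the zero code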
+ for t in range(neg.size(1) - 1):
+ neg[:, t] = rewards[:, t:].min(dim=-1).values
+ pos[:, t] = rewards[:, t:].max(dim=-1).values
+ s = (neg < 0).long() * neg + (neg >= 0).long() * pos
+
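+ # Each time step is encoded as height * width state tokens,
+ # followed by one reward, one lookahead reward, and one action
+ # token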
+ return torch.cat(
+ [
+ self.state2code(states.flatten(2)),
+ self.reward2code(rewards[:, :, None]),
+ self.lookahead_reward2code(s[:, :, None]),
+ self.action2code(actions[:, :, None]),
+ ],
+ dim=2,
+ ).flatten(1)
+
+ def seq2episodes(self, seq):
+ seq = seq.reshape(seq.size(0), -1, self.height * self.width + 3)
+ # Decode with the per-step layout used by episodes2seq: the state
+ # tokens first, then the reward, lookahead reward, and action
+ states = self.code2state(seq[:, :, : self.height * self.width])
+ states = states.reshape(states.size(0), states.size(1), self.height, self.width)
+ rewards = self.code2reward(seq[:, :, self.height * self.width])
+ lookahead_rewards = self.code2lookahead_reward(seq[:, :, self.height * self.width + 1])
+ actions = self.code2action(seq[:, :, self.height * self.width + 2])
+ return lookahead_rewards, states, actions, rewards
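+
+ # A minimal round-trip sanity check (a sketch: `task` stands for an
+ # instance of this class, with episode tensors produced to match its
+ # height, width and T):
+ #
+ #   seq = task.episodes2seq(states, actions, rewards)
+ #   _, s, a, r = task.seq2episodes(seq)
+ #   assert (s == states).all() and (a == actions).all() and (r == rewards).all()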
+
+ def seq2str(self, seq):
+ def token2str(t):
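+ # One printable character per token: cells map to "_#@T$", actions
+ # to "ISNEW" (idle, south, north, east, west), rewards to "-0+",
+ # and lookahead rewards to "n.pU"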
+ if (
+ t >= self.first_states_code
+ and t < self.first_states_code + self.nb_states_codes
+ ):
+ return "_#@T$"[t - self.first_states_code]
+ elif (
+ t >= self.first_actions_code
+ and t < self.first_actions_code + self.nb_actions_codes
+ ):
+ return "ISNEW"[t - self.first_actions_code]
+ elif (
+ t >= self.first_rewards_code
+ and t < self.first_rewards_code + self.nb_rewards_codes
+ ):
+ return "-0+"[t - self.first_rewards_code]
+ elif (
+ t >= self.first_lookahead_rewards_code
+ and t
+ < self.first_lookahead_rewards_code + self.nb_lookahead_rewards_codes
+ ):
+ return "n.pU"[t - self.first_lookahead_rewards_code]
+ else:
+ return "?"
+
+ return ["".join([token2str(x.item()) for x in row]) for row in seq]
+
+ ######################################################################
+
+ def episodes2str(
+ self,
+ lookahead_rewards,
+ states,
+ actions,
+ rewards,
+ unicode=False,
+ ansi_colors=False,
+ ):
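+ # Pick the glyphs used to draw the maze and its frame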
+ if unicode:
+ symbols = "·█@T$"
+ # vert, hori, cross, thin_hori = "║", "═", "╬", "─"
+ vert, hori, cross, thin_vert, thin_hori = "┃", "━", "╋", "│", "─"