######################################################################
-nb_states_codes = 5
-nb_actions_codes = 5
-nb_rewards_codes = 3
-nb_lookahead_rewards_codes = 4 # stands for -1, 0, +1, and UNKNOWN
-first_states_code = 0
-first_actions_code = first_states_code + nb_states_codes
-first_rewards_code = first_actions_code + nb_actions_codes
-first_lookahead_rewards_code = first_rewards_code + nb_rewards_codes
-nb_codes = first_lookahead_rewards_code + nb_lookahead_rewards_codes
-
-######################################################################
-
-
-def state2code(r):
- return r + first_states_code
-
-
-def code2state(r):
- return r - first_states_code
-
-
-def action2code(r):
- return r + first_actions_code
-
-
-def code2action(r):
- return r - first_actions_code
+class GreedWorld:
+ def __init__(self, height=6, width=6, T=10, nb_walls=3, nb_coins=2):
+ self.height = height
+ self.width = width
+ self.T = T
+ self.nb_walls = nb_walls
+ self.nb_coins = nb_coins
+
+ self.nb_states_codes = 5
+ self.nb_actions_codes = 5
+ self.nb_rewards_codes = 3
+ self.nb_lookahead_rewards_codes = 4 # stands for -1, 0, +1, and UNKNOWN
+
+ self.first_states_code = 0
+ self.first_actions_code = self.first_states_code + self.nb_states_codes
+ self.first_rewards_code = self.first_actions_code + self.nb_actions_codes
+ self.first_lookahead_rewards_code = (
+ self.first_rewards_code + self.nb_rewards_codes
+ )
+ self.nb_codes = (
+ self.first_lookahead_rewards_code + self.nb_lookahead_rewards_codes
+ )
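+        # Offsets of the individual fields within the token block of one
+        # time step (see it_len below for the field order)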
+ self.state_len = self.height * self.width
+ self.index_states = 0
+ self.index_reward = self.state_len
+ self.index_lookahead_reward = self.state_len + 1
+ self.index_action = self.state_len + 2
+        self.it_len = self.state_len + 3  # state / reward / lookahead_reward / action
-def reward2code(r):
- return r + 1 + first_rewards_code
+ def state2code(self, r):
+ return r + self.first_states_code
+ def code2state(self, r):
+ return r - self.first_states_code
-def code2reward(r):
- return r - first_rewards_code - 1
+ def action2code(self, r):
+ return r + self.first_actions_code
+ def code2action(self, r):
+ return r - self.first_actions_code
-def lookahead_reward2code(r):
- # -1, 0, +1 or 2 for UNKNOWN
- return r + 1 + first_lookahead_rewards_code
+ def reward2code(self, r):
+ return r + 1 + self.first_rewards_code
+ def code2reward(self, r):
+ return r - self.first_rewards_code - 1
-def code2lookahead_reward(r):
- return r - first_lookahead_rewards_code - 1
+ def lookahead_reward2code(self, r):
+ # -1, 0, +1 or 2 for UNKNOWN
+ return r + 1 + self.first_lookahead_rewards_code
+ def code2lookahead_reward(self, r):
+ return r - self.first_lookahead_rewards_code - 1
-######################################################################
+ ######################################################################
+ def generate_episodes(self, nb):
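+        # Random field used to pick wall cells; the border is zeroed so that
+        # walls are only placed in the interior of the grid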
+ rnd = torch.rand(nb, self.height, self.width)
+ rnd[:, 0, :] = 0
+ rnd[:, -1, :] = 0
+ rnd[:, :, 0] = 0
+ rnd[:, :, -1] = 0
+ wall = 0
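+        # Add the walls one at a time: the argmax of the random field picks a
+        # cell, which is turned into a one-hot map and accumulated; the field
+        # is then zeroed on occupied cells so the next wall lands elsewhere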
+ for k in range(self.nb_walls):
+ wall = wall + (
+ rnd.flatten(1).argmax(dim=1)[:, None]
+ == torch.arange(rnd.flatten(1).size(1))[None, :]
+ ).long().reshape(rnd.size())
-def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3, nb_coins=2):
- rnd = torch.rand(nb, height, width)
- rnd[:, 0, :] = 0
- rnd[:, -1, :] = 0
- rnd[:, :, 0] = 0
- rnd[:, :, -1] = 0
- wall = 0
- for k in range(nb_walls):
- wall = wall + (
- rnd.flatten(1).argmax(dim=1)[:, None]
- == torch.arange(rnd.flatten(1).size(1))[None, :]
- ).long().reshape(rnd.size())
+ rnd = rnd * (1 - wall.clamp(max=1))
+ rnd = torch.rand(nb, self.height, self.width)
+        # Do not put a coin at the agent's starting position
+        rnd[:, 0, 0] = 0
+ coins = torch.zeros(nb, self.T, self.height, self.width, dtype=torch.int64)
rnd = rnd * (1 - wall.clamp(max=1))
-
- rnd = torch.rand(nb, height, width)
- rnd[:, 0, 0] = 0 # Do not put coin at the agent's starting
- # position
- coins = torch.zeros(nb, T, height, width, dtype=torch.int64)
- rnd = rnd * (1 - wall.clamp(max=1))
- for k in range(nb_coins):
- coins[:, 0] = coins[:, 0] + (
- rnd.flatten(1).argmax(dim=1)[:, None]
- == torch.arange(rnd.flatten(1).size(1))[None, :]
- ).long().reshape(rnd.size())
-
- rnd = rnd * (1 - coins[:, 0].clamp(max=1))
-
- states = wall[:, None, :, :].expand(-1, T, -1, -1).clone()
-
- agent = torch.zeros(states.size(), dtype=torch.int64)
- agent[:, 0, 0, 0] = 1
- agent_actions = torch.randint(5, (nb, T))
- rewards = torch.zeros(nb, T, dtype=torch.int64)
-
- troll = torch.zeros(states.size(), dtype=torch.int64)
- troll[:, 0, -1, -1] = 1
- troll_actions = torch.randint(5, (nb, T))
-
- all_moves = agent.new(nb, 5, height, width)
- for t in range(T - 1):
- all_moves.zero_()
- all_moves[:, 0] = agent[:, t]
- all_moves[:, 1, 1:, :] = agent[:, t, :-1, :]
- all_moves[:, 2, :-1, :] = agent[:, t, 1:, :]
- all_moves[:, 3, :, 1:] = agent[:, t, :, :-1]
- all_moves[:, 4, :, :-1] = agent[:, t, :, 1:]
- a = F.one_hot(agent_actions[:, t], num_classes=5)[:, :, None, None]
- after_move = (all_moves * a).sum(dim=1)
- collision = (
- (after_move * (1 - wall) * (1 - troll[:, t]))
- .flatten(1)
- .sum(dim=1)[:, None, None]
- == 0
- ).long()
- agent[:, t + 1] = collision * agent[:, t] + (1 - collision) * after_move
-
- all_moves.zero_()
- all_moves[:, 0] = troll[:, t]
- all_moves[:, 1, 1:, :] = troll[:, t, :-1, :]
- all_moves[:, 2, :-1, :] = troll[:, t, 1:, :]
- all_moves[:, 3, :, 1:] = troll[:, t, :, :-1]
- all_moves[:, 4, :, :-1] = troll[:, t, :, 1:]
- a = F.one_hot(troll_actions[:, t], num_classes=5)[:, :, None, None]
- after_move = (all_moves * a).sum(dim=1)
- collision = (
- (after_move * (1 - wall) * (1 - agent[:, t + 1]))
- .flatten(1)
- .sum(dim=1)[:, None, None]
- == 0
- ).long()
- troll[:, t + 1] = collision * troll[:, t] + (1 - collision) * after_move
-
- hit = (
- (agent[:, t + 1, 1:, :] * troll[:, t + 1, :-1, :]).flatten(1).sum(dim=1)
- + (agent[:, t + 1, :-1, :] * troll[:, t + 1, 1:, :]).flatten(1).sum(dim=1)
- + (agent[:, t + 1, :, 1:] * troll[:, t + 1, :, :-1]).flatten(1).sum(dim=1)
- + (agent[:, t + 1, :, :-1] * troll[:, t + 1, :, 1:]).flatten(1).sum(dim=1)
- )
- hit = (hit > 0).long()
-
- # assert hit.min() == 0 and hit.max() <= 1
-
- got_coin = (agent[:, t + 1] * coins[:, t]).flatten(1).sum(dim=1)
- coins[:, t + 1] = coins[:, t] * (1 - agent[:, t + 1])
-
- rewards[:, t + 1] = -hit + (1 - hit) * got_coin
-
- states = states + 2 * agent + 3 * troll + 4 * coins * (1 - troll)
-
- return states, agent_actions, rewards
-
-
-######################################################################
-
-
-def episodes2seq(states, actions, rewards):
- neg = rewards.new_zeros(rewards.size())
- pos = rewards.new_zeros(rewards.size())
- for t in range(neg.size(1) - 1):
- neg[:, t] = rewards[:, t:].min(dim=-1).values
- pos[:, t] = rewards[:, t:].max(dim=-1).values
- s = (neg < 0).long() * neg + (neg >= 0).long() * pos
-
- return torch.cat(
- [
- lookahead_reward2code(s[:, :, None]),
- state2code(states.flatten(2)),
- action2code(actions[:, :, None]),
- reward2code(rewards[:, :, None]),
- ],
- dim=2,
- ).flatten(1)
-
-
-def seq2episodes(seq, height, width):
- seq = seq.reshape(seq.size(0), -1, height * width + 3)
- lookahead_rewards = code2lookahead_reward(seq[:, :, 0])
- states = code2state(seq[:, :, 1 : height * width + 1])
- states = states.reshape(states.size(0), states.size(1), height, width)
- actions = code2action(seq[:, :, height * width + 1])
- rewards = code2reward(seq[:, :, height * width + 2])
- return lookahead_rewards, states, actions, rewards
-
-
-def seq2str(seq):
- def token2str(t):
- if t >= first_states_code and t < first_states_code + nb_states_codes:
- return " #@T$"[t - first_states_code]
- elif t >= first_actions_code and t < first_actions_code + nb_actions_codes:
- return "ISNEW"[t - first_actions_code]
- elif t >= first_rewards_code and t < first_rewards_code + nb_rewards_codes:
- return "-0+"[t - first_rewards_code]
- elif (
- t >= first_lookahead_rewards_code
- and t < first_lookahead_rewards_code + nb_lookahead_rewards_codes
- ):
- return "n.pU"[t - first_lookahead_rewards_code]
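+        # Place the coins at time 0 with the same argmax / one-hot trick, on
+        # cells that are neither walls nor already holding a coin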
+ for k in range(self.nb_coins):
+ coins[:, 0] = coins[:, 0] + (
+ rnd.flatten(1).argmax(dim=1)[:, None]
+ == torch.arange(rnd.flatten(1).size(1))[None, :]
+ ).long().reshape(rnd.size())
+
+            rnd = rnd * (1 - coins[:, 0].clamp(max=1))
+
+ states = wall[:, None, :, :].expand(-1, self.T, -1, -1).clone()
+
+ agent = torch.zeros(states.size(), dtype=torch.int64)
+ agent[:, 0, 0, 0] = 1
+ agent_actions = torch.randint(5, (nb, self.T))
+ rewards = torch.zeros(nb, self.T, dtype=torch.int64)
+
+ troll = torch.zeros(states.size(), dtype=torch.int64)
+ troll[:, 0, -1, -1] = 1
+ troll_actions = torch.randint(5, (nb, self.T))
+
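+        # all_moves[:, k] holds the agent (or troll) occupancy map shifted
+        # according to action k: 0 = idle, 1 = south, 2 = north, 3 = east,
+        # 4 = west (cf. "ISNEW" in seq2str)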
+ all_moves = agent.new(nb, 5, self.height, self.width)
+ for t in range(self.T - 1):
+ all_moves.zero_()
+ all_moves[:, 0] = agent[:, t]
+ all_moves[:, 1, 1:, :] = agent[:, t, :-1, :]
+ all_moves[:, 2, :-1, :] = agent[:, t, 1:, :]
+ all_moves[:, 3, :, 1:] = agent[:, t, :, :-1]
+ all_moves[:, 4, :, :-1] = agent[:, t, :, 1:]
+ a = F.one_hot(agent_actions[:, t], num_classes=5)[:, :, None, None]
+ after_move = (all_moves * a).sum(dim=1)
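+            # If the move goes off the grid or lands on a wall or on the
+            # troll, the masked map sums to zero and the previous position is
+            # kept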
+ collision = (
+ (after_move * (1 - wall) * (1 - troll[:, t]))
+ .flatten(1)
+ .sum(dim=1)[:, None, None]
+ == 0
+ ).long()
+ agent[:, t + 1] = collision * agent[:, t] + (1 - collision) * after_move
+
+ all_moves.zero_()
+ all_moves[:, 0] = troll[:, t]
+ all_moves[:, 1, 1:, :] = troll[:, t, :-1, :]
+ all_moves[:, 2, :-1, :] = troll[:, t, 1:, :]
+ all_moves[:, 3, :, 1:] = troll[:, t, :, :-1]
+ all_moves[:, 4, :, :-1] = troll[:, t, :, 1:]
+ a = F.one_hot(troll_actions[:, t], num_classes=5)[:, :, None, None]
+ after_move = (all_moves * a).sum(dim=1)
+ collision = (
+ (after_move * (1 - wall) * (1 - agent[:, t + 1]))
+ .flatten(1)
+ .sum(dim=1)[:, None, None]
+ == 0
+ ).long()
+ troll[:, t + 1] = collision * troll[:, t] + (1 - collision) * after_move
+
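+            # The agent is hit when the troll ends up in one of the four
+            # cells adjacent to it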
+ hit = (
+ (agent[:, t + 1, 1:, :] * troll[:, t + 1, :-1, :]).flatten(1).sum(dim=1)
+ + (agent[:, t + 1, :-1, :] * troll[:, t + 1, 1:, :])
+ .flatten(1)
+ .sum(dim=1)
+ + (agent[:, t + 1, :, 1:] * troll[:, t + 1, :, :-1])
+ .flatten(1)
+ .sum(dim=1)
+ + (agent[:, t + 1, :, :-1] * troll[:, t + 1, :, 1:])
+ .flatten(1)
+ .sum(dim=1)
+ )
+ hit = (hit > 0).long()
+
+ # assert hit.min() == 0 and hit.max() <= 1
+
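+            # Reward is -1 when the troll hits the agent, +1 if a coin was
+            # collected at the new position, and 0 otherwise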
+ got_coin = (agent[:, t + 1] * coins[:, t]).flatten(1).sum(dim=1)
+ coins[:, t + 1] = coins[:, t] * (1 - agent[:, t + 1])
+
+ rewards[:, t + 1] = -hit + (1 - hit) * got_coin
+
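+        # Cell encoding: 0 empty, 1 wall, 2 agent, 3 troll, 4 coin
+        # (matching the symbols used for display)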
+ states = states + 2 * agent + 3 * troll + 4 * coins * (1 - troll)
+
+ return states, agent_actions, rewards
+
+ ######################################################################
+
+ def episodes2seq(self, states, actions, rewards):
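+        # The lookahead reward at time t summarizes the rewards from t onward:
+        # the most negative one if any is negative, otherwise the largest one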
+ neg = rewards.new_zeros(rewards.size())
+ pos = rewards.new_zeros(rewards.size())
+ for t in range(neg.size(1) - 1):
+ neg[:, t] = rewards[:, t:].min(dim=-1).values
+ pos[:, t] = rewards[:, t:].max(dim=-1).values
+ s = (neg < 0).long() * neg + (neg >= 0).long() * pos
+
+ return torch.cat(
+ [
+ self.state2code(states.flatten(2)),
+ self.reward2code(rewards[:, :, None]),
+ self.lookahead_reward2code(s[:, :, None]),
+ self.action2code(actions[:, :, None]),
+ ],
+ dim=2,
+ ).flatten(1)
+
+    def seq2episodes(self, seq):
+        seq = seq.reshape(seq.size(0), -1, self.it_len)
+        lookahead_rewards = self.code2lookahead_reward(
+            seq[:, :, self.index_lookahead_reward]
+        )
+        states = self.code2state(
+            seq[:, :, self.index_states : self.index_states + self.state_len]
+        )
+        states = states.reshape(
+            states.size(0), states.size(1), self.height, self.width
+        )
+        actions = self.code2action(seq[:, :, self.index_action])
+        rewards = self.code2reward(seq[:, :, self.index_reward])
+        return lookahead_rewards, states, actions, rewards
+
+ def seq2str(self, seq):
+ def token2str(t):
+ if (
+ t >= self.first_states_code
+ and t < self.first_states_code + self.nb_states_codes
+ ):
+ return "_#@T$"[t - self.first_states_code]
+ elif (
+ t >= self.first_actions_code
+ and t < self.first_actions_code + self.nb_actions_codes
+ ):
+ return "ISNEW"[t - self.first_actions_code]
+ elif (
+ t >= self.first_rewards_code
+ and t < self.first_rewards_code + self.nb_rewards_codes
+ ):
+ return "-0+"[t - self.first_rewards_code]
+ elif (
+ t >= self.first_lookahead_rewards_code
+ and t
+ < self.first_lookahead_rewards_code + self.nb_lookahead_rewards_codes
+ ):
+ return "n.pU"[t - self.first_lookahead_rewards_code]
+ else:
+ return "?"
+
+ return ["".join([token2str(x.item()) for x in row]) for row in seq]
+
+ ######################################################################
+
+ def episodes2str(
+ self,
+ lookahead_rewards,
+ states,
+ actions,
+ rewards,
+ unicode=False,
+ ansi_colors=False,
+ ):
+ if unicode:
+ symbols = "·█@T$"
+ # vert, hori, cross, thin_hori = "║", "═", "╬", "─"
+ vert, hori, cross, thin_vert, thin_hori = "┃", "━", "╋", "│", "─"
else:
- return "?"
-
- return ["".join([token2str(x.item()) for x in row]) for row in seq]
-
-
-######################################################################
-
-
-def episodes2str(
- lookahead_rewards, states, actions, rewards, unicode=False, ansi_colors=False
-):
- if unicode:
- symbols = "·█@T$"
- # vert, hori, cross, thin_hori = "║", "═", "╬", "─"
- vert, hori, cross, thin_vert, thin_hori = "┃", "━", "╋", "│", "─"
- else:
- symbols = " #@T$"
- vert, hori, cross, thin_vert, thin_hori = "|", "-", "+", "|", "-"
-
- hline = (cross + hori * states.size(-1)) * states.size(1) + cross + "\n"
-
- result = hline
-
- for n in range(states.size(0)):
+ symbols = " #@T$"
+ vert, hori, cross, thin_vert, thin_hori = "|", "-", "+", "|", "-"
+
+ hline = (cross + hori * states.size(-1)) * states.size(1) + cross + "\n"
+
+ result = hline
+
+ for n in range(states.size(0)):
+
+ def state_symbol(v):
+ v = v.item()
+ return "?" if v < 0 or v >= len(symbols) else symbols[v]
+
+ for i in range(states.size(2)):
+ result += (
+ vert
+ + vert.join(
+ [
+ "".join([state_symbol(v) for v in row])
+ for row in states[n, :, i]
+ ]
+ )
+ + vert
+ + "\n"
+ )
- def state_symbol(v):
- v = v.item()
- return "?" if v < 0 or v >= len(symbols) else symbols[v]
+ # result += (vert + thin_hori * states.size(-1)) * states.size(1) + vert + "\n"
+
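+        # One status cell per time step: "action/reward" on the left and the
+        # lookahead reward right-aligned, padded to the grid width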
+ def status_bar(a, r, lr=None):
+ a, r = a.item(), r.item()
+ sb_a = "ISNEW"[a] if a >= 0 and a < 5 else "?"
+ sb_r = "- +"[r + 1] if r in {-1, 0, 1} else "?"
+ if lr is None:
+ sb_lr = ""
+ else:
+ lr = lr.item()
+ sb_lr = "n pU"[lr + 1] if lr in {-1, 0, 1, 2} else "?"
+ return (
+ sb_a
+ + "/"
+ + sb_r
+ + " " * (states.size(-1) - 1 - len(sb_a + sb_r + sb_lr))
+ + sb_lr
+ )
- for i in range(states.size(2)):
result += (
vert
+ vert.join(
- ["".join([state_symbol(v) for v in row]) for row in states[n, :, i]]
+ [
+ status_bar(a, r, lr)
+ for a, r, lr in zip(
+ actions[n], rewards[n], lookahead_rewards[n]
+ )
+ ]
)
+ vert
+ "\n"
)
- # result += (vert + thin_hori * states.size(-1)) * states.size(1) + vert + "\n"
+ result += hline
- def status_bar(a, r, lr=None):
- a, r = a.item(), r.item()
- sb_a = "ISNEW"[a] if a >= 0 and a < 5 else "?"
- sb_r = "- +"[r + 1] if r in {-1, 0, 1} else "?"
- if lr is None:
- sb_lr = ""
- else:
- lr = lr.item()
- sb_lr = "n pU"[lr + 1] if lr in {-1, 0, 1, 2} else "?"
- return (
- sb_a
- + "/"
- + sb_r
- + " " * (states.size(-1) - 1 - len(sb_a + sb_r + sb_lr))
- + sb_lr
- )
-
- result += (
- vert
- + vert.join(
- [
- status_bar(a, r, lr)
- for a, r, lr in zip(actions[n], rewards[n], lookahead_rewards[n])
- ]
- )
- + vert
- + "\n"
- )
-
- result += hline
+ if ansi_colors:
+ for u, c in [("T", 31), ("@", 32), ("$", 34)]:
+ result = result.replace(u, f"\u001b[{c}m{u}\u001b[0m")
- if ansi_colors:
- for u, c in [("T", 31), ("@", 32), ("$", 34)]:
- result = result.replace(u, f"\u001b[{c}m{u}\u001b[0m")
+ return result
- return result
+ ######################################################################
+ def save_seq_as_anim_script(self, seq, filename):
+        it_len = self.it_len
-######################################################################
-
-
-def save_seq_as_anim_script(seq, filename):
- it_len = height * width + 3
-
- seq = (
- seq.reshape(seq.size(0), -1, it_len)
- .permute(1, 0, 2)
- .reshape(T, seq.size(0), -1)
- )
+ seq = (
+ seq.reshape(seq.size(0), -1, it_len)
+ .permute(1, 0, 2)
+ .reshape(self.T, seq.size(0), -1)
+ )
- with open(filename, "w") as f:
- for t in range(T):
- f.write("clear\n")
- f.write("cat << EOF\n")
- # for i in range(seq.size(2)):
- # lr, s, a, r = seq2episodes(seq[t : t + 1, :, i], height, width)
- lr, s, a, r = seq2episodes(
- seq[t : t + 1, :].reshape(5, 10 * it_len), height, width
- )
- f.write(episodes2str(lr, s, a, r, unicode=True, ansi_colors=True))
- f.write("EOF\n")
- f.write("sleep 0.25\n")
- print(f"Saved {filename}")
+ with open(filename, "w") as f:
+ for t in range(self.T):
+ # f.write("clear\n")
+ f.write("cat << EOF\n")
+ f.write("\u001b[H")
+ # for i in range(seq.size(2)):
+ # lr, s, a, r = seq2episodes(seq[t : t + 1, :, i], self.height, self.width)
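+                # Lay the episodes of this time step out on 8 display rows
+                # (assumes the number of episodes is a multiple of 8)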
+ lr, s, a, r = self.seq2episodes(seq[t : t + 1, :].reshape(8, -1))
+ f.write(self.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True))
+ f.write("EOF\n")
+ f.write("sleep 0.25\n")
+ print(f"Saved {filename}")
if __name__ == "__main__":
- nb, height, width, T, nb_walls = 6, 5, 7, 10, 5
- states, actions, rewards = generate_episodes(nb, height, width, T, nb_walls)
- seq = episodes2seq(states, actions, rewards)
- lr, s, a, r = seq2episodes(seq, height, width)
- print(episodes2str(lr, s, a, r, unicode=True, ansi_colors=True))
-
- # print()
- # for s in seq2str(seq):
- # print(s)
-
- nb, T = 50, 100
- states, actions, rewards = generate_episodes(
- nb=nb, height=height, width=width, T=T, nb_walls=3
- )
- seq = episodes2seq(states, actions, rewards)
- save_seq_as_anim_script(seq, "anim.sh")
+ gw = GreedWorld(height=5, width=7, T=10, nb_walls=4, nb_coins=2)
+ states, actions, rewards = gw.generate_episodes(nb=6)
+ seq = gw.episodes2seq(states, actions, rewards)
+ lr, s, a, r = gw.seq2episodes(seq)
+ print(gw.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True))
+
+ print()
+ for s in gw.seq2str(seq):
+ print(s)
+
+ gw = GreedWorld(height=5, width=7, T=100, nb_walls=4, nb_coins=2)
+ states, actions, rewards = gw.generate_episodes(nb=128)
+ seq = gw.episodes2seq(states, actions, rewards)
+ gw.save_seq_as_anim_script(seq, "anim.sh")
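+    # The generated script can be played back in a terminal, e.g. with: sh anim.sh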
width,
T,
nb_walls,
+ nb_coins,
logger=None,
device=torch.device("cpu"),
):
self.batch_size = batch_size
self.device = device
- self.height = height
- self.width = width
- states, actions, rewards = greed.generate_episodes(
- nb_train_samples + nb_test_samples, height, width, T, nb_walls
+ self.world = greed.GreedWorld(height, width, T, nb_walls, nb_coins)
+
+ states, actions, rewards = self.world.generate_episodes(
+ nb_train_samples + nb_test_samples
)
- seq = greed.episodes2seq(states, actions, rewards)
- # seq = seq[:, seq.size(1) // 3 : 2 * seq.size(1) // 3]
+ seq = self.world.episodes2seq(states, actions, rewards)
self.train_input = seq[:nb_train_samples].to(self.device)
self.test_input = seq[nb_train_samples:].to(self.device)
- self.state_len = self.height * self.width
- self.index_lookahead_reward = 0
- self.index_states = 1
- self.index_action = self.state_len + 1
- self.index_reward = self.state_len + 2
- self.it_len = self.state_len + 3 # lookahead_reward / state / action / reward
-
def wipe_lookahead_rewards(self, batch):
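+        # Set the lookahead-reward token to UNKNOWN for a random prefix of
+        # the iterations in each sequence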
t = torch.arange(batch.size(1), device=batch.device)[None, :]
u = torch.randint(batch.size(1), (batch.size(0), 1), device=batch.device)
lr_mask = (t <= u).long() * (
- t % self.it_len == self.index_lookahead_reward
+ t % self.world.it_len == self.world.index_lookahead_reward
).long()
- return lr_mask * greed.lookahead_reward2code(2) + (1 - lr_mask) * batch
+ return lr_mask * self.world.lookahead_reward2code(2) + (1 - lr_mask) * batch
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
yield self.wipe_lookahead_rewards(batch)
def vocabulary_size(self):
- return greed.nb_codes
+ return self.world.nb_codes
def thinking_autoregression(
self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
result = self.test_input[:250].clone()
# Erase all the content but that of the first iteration
- result[:, self.it_len :] = -1
+ result[:, self.world.it_len :] = -1
        # Set the lookahead_reward of the first iteration to UNKNOWN
- result[:, self.index_lookahead_reward] = greed.lookahead_reward2code(2)
+ result[:, self.world.index_lookahead_reward] = self.world.lookahead_reward2code(
+ 2
+ )
t = torch.arange(result.size(1), device=result.device)[None, :]
for u in tqdm.tqdm(
- range(0, result.size(1), self.it_len),
+ range(0, result.size(1), self.world.it_len),
desc="thinking",
):
            # Generate the next state but keep the initial one; the
            # lookahead_reward of past iterations is set to UNKNOWN
if u > 0:
result[
- :, u + self.index_lookahead_reward
- ] = greed.lookahead_reward2code(2)
- ar_mask = (t >= u + self.index_states).long() * (
- t < u + self.index_states + self.state_len
+ :, u + self.world.index_lookahead_reward
+ ] = self.world.lookahead_reward2code(2)
+ ar_mask = (t >= u + self.world.index_states).long() * (
+ t < u + self.world.index_states + self.world.state_len
).long()
ar(result, ar_mask)
            # Generate the action and reward with the lookahead_reward set to +1
- result[:, u + self.index_lookahead_reward] = greed.lookahead_reward2code(1)
- ar_mask = (t >= u + self.index_action).long() * (
- t <= u + self.index_reward
+ result[
+ :, u + self.world.index_lookahead_reward
+ ] = self.world.lookahead_reward2code(1)
+ ar_mask = (t >= u + self.world.index_reward).long() * (
+ t <= u + self.world.index_action
).long()
ar(result, ar_mask)
# Set the lookahead_reward to UNKNOWN for the next iterations
- result[:, u + self.index_lookahead_reward] = greed.lookahead_reward2code(2)
+ result[
+ :, u + self.world.index_lookahead_reward
+ ] = self.world.lookahead_reward2code(2)
filename = os.path.join(result_dir, f"test_thinking_compute_{n_epoch:04d}.txt")
with open(filename, "w") as f:
for n in range(10):
for s in snapshots:
- lr, s, a, r = greed.seq2episodes(
- s[n : n + 1], self.height, self.width
+ lr, s, a, r = self.world.seq2episodes(
+ s[n : n + 1],
)
- str = greed.episodes2str(
+ str = self.world.episodes2str(
lr, s, a, r, unicode=True, ansi_colors=True
)
f.write(str)
# Saving the generated sequences
- lr, s, a, r = greed.seq2episodes(result, self.height, self.width)
- str = greed.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
+ lr, s, a, r = self.world.seq2episodes(result)
+ str = self.world.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
filename = os.path.join(result_dir, f"test_thinking_seq_{n_epoch:04d}.txt")
with open(filename, "w") as f:
# Saving the ground truth
- lr, s, a, r = greed.seq2episodes(
+ lr, s, a, r = self.world.seq2episodes(
result,
- self.height,
- self.width,
)
- str = greed.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
+ str = self.world.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
filename = os.path.join(result_dir, f"test_true_seq_{n_epoch:04d}.txt")
with open(filename, "w") as f:
# Re-generating from the first frame
ar_mask = (
- torch.arange(result.size(1), device=result.device)
- >= self.height * self.width + 3
+ torch.arange(result.size(1), device=result.device) >= self.world.it_len
).long()[None, :]
ar_mask = ar_mask.expand_as(result)
result *= 1 - ar_mask # paraaaaanoiaaaaaaa
# Saving the generated sequences
- lr, s, a, r = greed.seq2episodes(
+ lr, s, a, r = self.world.seq2episodes(
result,
- self.height,
- self.width,
)
- str = greed.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
+ str = self.world.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
filename = os.path.join(result_dir, f"test_seq_{n_epoch:04d}.txt")
with open(filename, "w") as f: