From 41164ce7ce1d071a4eb71f72ff277933794cf316 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 27 Mar 2024 21:17:55 +0100 Subject: [PATCH] Update. --- greed.py | 578 +++++++++++++++++++++++++++++-------------------------- main.py | 9 +- mygpt.py | 1 + tasks.py | 78 ++++---- 4 files changed, 343 insertions(+), 323 deletions(-) diff --git a/greed.py b/greed.py index dc11d14..47cfb40 100755 --- a/greed.py +++ b/greed.py @@ -11,315 +11,339 @@ from torch.nn import functional as F ###################################################################### -nb_states_codes = 5 -nb_actions_codes = 5 -nb_rewards_codes = 3 -nb_lookahead_rewards_codes = 4 # stands for -1, 0, +1, and UNKNOWN -first_states_code = 0 -first_actions_code = first_states_code + nb_states_codes -first_rewards_code = first_actions_code + nb_actions_codes -first_lookahead_rewards_code = first_rewards_code + nb_rewards_codes -nb_codes = first_lookahead_rewards_code + nb_lookahead_rewards_codes - -###################################################################### - - -def state2code(r): - return r + first_states_code - - -def code2state(r): - return r - first_states_code - - -def action2code(r): - return r + first_actions_code - - -def code2action(r): - return r - first_actions_code +class GreedWorld: + def __init__(self, height=6, width=6, T=10, nb_walls=3, nb_coins=2): + self.height = height + self.width = width + self.T = T + self.nb_walls = nb_walls + self.nb_coins = nb_coins + + self.nb_states_codes = 5 + self.nb_actions_codes = 5 + self.nb_rewards_codes = 3 + self.nb_lookahead_rewards_codes = 4 # stands for -1, 0, +1, and UNKNOWN + + self.first_states_code = 0 + self.first_actions_code = self.first_states_code + self.nb_states_codes + self.first_rewards_code = self.first_actions_code + self.nb_actions_codes + self.first_lookahead_rewards_code = ( + self.first_rewards_code + self.nb_rewards_codes + ) + self.nb_codes = ( + self.first_lookahead_rewards_code + self.nb_lookahead_rewards_codes + ) + self.state_len = self.height * self.width + self.index_states = 0 + self.index_reward = self.state_len + self.index_lookahead_reward = self.state_len + 1 + self.index_action = self.state_len + 2 + self.it_len = self.state_len + 3 # lookahead_reward / state / action / reward -def reward2code(r): - return r + 1 + first_rewards_code + def state2code(self, r): + return r + self.first_states_code + def code2state(self, r): + return r - self.first_states_code -def code2reward(r): - return r - first_rewards_code - 1 + def action2code(self, r): + return r + self.first_actions_code + def code2action(self, r): + return r - self.first_actions_code -def lookahead_reward2code(r): - # -1, 0, +1 or 2 for UNKNOWN - return r + 1 + first_lookahead_rewards_code + def reward2code(self, r): + return r + 1 + self.first_rewards_code + def code2reward(self, r): + return r - self.first_rewards_code - 1 -def code2lookahead_reward(r): - return r - first_lookahead_rewards_code - 1 + def lookahead_reward2code(self, r): + # -1, 0, +1 or 2 for UNKNOWN + return r + 1 + self.first_lookahead_rewards_code + def code2lookahead_reward(self, r): + return r - self.first_lookahead_rewards_code - 1 -###################################################################### + ###################################################################### + def generate_episodes(self, nb): + rnd = torch.rand(nb, self.height, self.width) + rnd[:, 0, :] = 0 + rnd[:, -1, :] = 0 + rnd[:, :, 0] = 0 + rnd[:, :, -1] = 0 + wall = 0 + for k in range(self.nb_walls): + wall = wall + ( + rnd.flatten(1).argmax(dim=1)[:, None] + == torch.arange(rnd.flatten(1).size(1))[None, :] + ).long().reshape(rnd.size()) -def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3, nb_coins=2): - rnd = torch.rand(nb, height, width) - rnd[:, 0, :] = 0 - rnd[:, -1, :] = 0 - rnd[:, :, 0] = 0 - rnd[:, :, -1] = 0 - wall = 0 - for k in range(nb_walls): - wall = wall + ( - rnd.flatten(1).argmax(dim=1)[:, None] - == torch.arange(rnd.flatten(1).size(1))[None, :] - ).long().reshape(rnd.size()) + rnd = rnd * (1 - wall.clamp(max=1)) + rnd = torch.rand(nb, self.height, self.width) + rnd[:, 0, 0] = 0 # Do not put coin at the agent's starting + # position + coins = torch.zeros(nb, self.T, self.height, self.width, dtype=torch.int64) rnd = rnd * (1 - wall.clamp(max=1)) - - rnd = torch.rand(nb, height, width) - rnd[:, 0, 0] = 0 # Do not put coin at the agent's starting - # position - coins = torch.zeros(nb, T, height, width, dtype=torch.int64) - rnd = rnd * (1 - wall.clamp(max=1)) - for k in range(nb_coins): - coins[:, 0] = coins[:, 0] + ( - rnd.flatten(1).argmax(dim=1)[:, None] - == torch.arange(rnd.flatten(1).size(1))[None, :] - ).long().reshape(rnd.size()) - - rnd = rnd * (1 - coins[:, 0].clamp(max=1)) - - states = wall[:, None, :, :].expand(-1, T, -1, -1).clone() - - agent = torch.zeros(states.size(), dtype=torch.int64) - agent[:, 0, 0, 0] = 1 - agent_actions = torch.randint(5, (nb, T)) - rewards = torch.zeros(nb, T, dtype=torch.int64) - - troll = torch.zeros(states.size(), dtype=torch.int64) - troll[:, 0, -1, -1] = 1 - troll_actions = torch.randint(5, (nb, T)) - - all_moves = agent.new(nb, 5, height, width) - for t in range(T - 1): - all_moves.zero_() - all_moves[:, 0] = agent[:, t] - all_moves[:, 1, 1:, :] = agent[:, t, :-1, :] - all_moves[:, 2, :-1, :] = agent[:, t, 1:, :] - all_moves[:, 3, :, 1:] = agent[:, t, :, :-1] - all_moves[:, 4, :, :-1] = agent[:, t, :, 1:] - a = F.one_hot(agent_actions[:, t], num_classes=5)[:, :, None, None] - after_move = (all_moves * a).sum(dim=1) - collision = ( - (after_move * (1 - wall) * (1 - troll[:, t])) - .flatten(1) - .sum(dim=1)[:, None, None] - == 0 - ).long() - agent[:, t + 1] = collision * agent[:, t] + (1 - collision) * after_move - - all_moves.zero_() - all_moves[:, 0] = troll[:, t] - all_moves[:, 1, 1:, :] = troll[:, t, :-1, :] - all_moves[:, 2, :-1, :] = troll[:, t, 1:, :] - all_moves[:, 3, :, 1:] = troll[:, t, :, :-1] - all_moves[:, 4, :, :-1] = troll[:, t, :, 1:] - a = F.one_hot(troll_actions[:, t], num_classes=5)[:, :, None, None] - after_move = (all_moves * a).sum(dim=1) - collision = ( - (after_move * (1 - wall) * (1 - agent[:, t + 1])) - .flatten(1) - .sum(dim=1)[:, None, None] - == 0 - ).long() - troll[:, t + 1] = collision * troll[:, t] + (1 - collision) * after_move - - hit = ( - (agent[:, t + 1, 1:, :] * troll[:, t + 1, :-1, :]).flatten(1).sum(dim=1) - + (agent[:, t + 1, :-1, :] * troll[:, t + 1, 1:, :]).flatten(1).sum(dim=1) - + (agent[:, t + 1, :, 1:] * troll[:, t + 1, :, :-1]).flatten(1).sum(dim=1) - + (agent[:, t + 1, :, :-1] * troll[:, t + 1, :, 1:]).flatten(1).sum(dim=1) - ) - hit = (hit > 0).long() - - # assert hit.min() == 0 and hit.max() <= 1 - - got_coin = (agent[:, t + 1] * coins[:, t]).flatten(1).sum(dim=1) - coins[:, t + 1] = coins[:, t] * (1 - agent[:, t + 1]) - - rewards[:, t + 1] = -hit + (1 - hit) * got_coin - - states = states + 2 * agent + 3 * troll + 4 * coins * (1 - troll) - - return states, agent_actions, rewards - - -###################################################################### - - -def episodes2seq(states, actions, rewards): - neg = rewards.new_zeros(rewards.size()) - pos = rewards.new_zeros(rewards.size()) - for t in range(neg.size(1) - 1): - neg[:, t] = rewards[:, t:].min(dim=-1).values - pos[:, t] = rewards[:, t:].max(dim=-1).values - s = (neg < 0).long() * neg + (neg >= 0).long() * pos - - return torch.cat( - [ - lookahead_reward2code(s[:, :, None]), - state2code(states.flatten(2)), - action2code(actions[:, :, None]), - reward2code(rewards[:, :, None]), - ], - dim=2, - ).flatten(1) - - -def seq2episodes(seq, height, width): - seq = seq.reshape(seq.size(0), -1, height * width + 3) - lookahead_rewards = code2lookahead_reward(seq[:, :, 0]) - states = code2state(seq[:, :, 1 : height * width + 1]) - states = states.reshape(states.size(0), states.size(1), height, width) - actions = code2action(seq[:, :, height * width + 1]) - rewards = code2reward(seq[:, :, height * width + 2]) - return lookahead_rewards, states, actions, rewards - - -def seq2str(seq): - def token2str(t): - if t >= first_states_code and t < first_states_code + nb_states_codes: - return " #@T$"[t - first_states_code] - elif t >= first_actions_code and t < first_actions_code + nb_actions_codes: - return "ISNEW"[t - first_actions_code] - elif t >= first_rewards_code and t < first_rewards_code + nb_rewards_codes: - return "-0+"[t - first_rewards_code] - elif ( - t >= first_lookahead_rewards_code - and t < first_lookahead_rewards_code + nb_lookahead_rewards_codes - ): - return "n.pU"[t - first_lookahead_rewards_code] + for k in range(self.nb_coins): + coins[:, 0] = coins[:, 0] + ( + rnd.flatten(1).argmax(dim=1)[:, None] + == torch.arange(rnd.flatten(1).size(1))[None, :] + ).long().reshape(rnd.size()) + + rnd = rnd * (1 - coins[:, 0].clamp(max=1)) + + states = wall[:, None, :, :].expand(-1, self.T, -1, -1).clone() + + agent = torch.zeros(states.size(), dtype=torch.int64) + agent[:, 0, 0, 0] = 1 + agent_actions = torch.randint(5, (nb, self.T)) + rewards = torch.zeros(nb, self.T, dtype=torch.int64) + + troll = torch.zeros(states.size(), dtype=torch.int64) + troll[:, 0, -1, -1] = 1 + troll_actions = torch.randint(5, (nb, self.T)) + + all_moves = agent.new(nb, 5, self.height, self.width) + for t in range(self.T - 1): + all_moves.zero_() + all_moves[:, 0] = agent[:, t] + all_moves[:, 1, 1:, :] = agent[:, t, :-1, :] + all_moves[:, 2, :-1, :] = agent[:, t, 1:, :] + all_moves[:, 3, :, 1:] = agent[:, t, :, :-1] + all_moves[:, 4, :, :-1] = agent[:, t, :, 1:] + a = F.one_hot(agent_actions[:, t], num_classes=5)[:, :, None, None] + after_move = (all_moves * a).sum(dim=1) + collision = ( + (after_move * (1 - wall) * (1 - troll[:, t])) + .flatten(1) + .sum(dim=1)[:, None, None] + == 0 + ).long() + agent[:, t + 1] = collision * agent[:, t] + (1 - collision) * after_move + + all_moves.zero_() + all_moves[:, 0] = troll[:, t] + all_moves[:, 1, 1:, :] = troll[:, t, :-1, :] + all_moves[:, 2, :-1, :] = troll[:, t, 1:, :] + all_moves[:, 3, :, 1:] = troll[:, t, :, :-1] + all_moves[:, 4, :, :-1] = troll[:, t, :, 1:] + a = F.one_hot(troll_actions[:, t], num_classes=5)[:, :, None, None] + after_move = (all_moves * a).sum(dim=1) + collision = ( + (after_move * (1 - wall) * (1 - agent[:, t + 1])) + .flatten(1) + .sum(dim=1)[:, None, None] + == 0 + ).long() + troll[:, t + 1] = collision * troll[:, t] + (1 - collision) * after_move + + hit = ( + (agent[:, t + 1, 1:, :] * troll[:, t + 1, :-1, :]).flatten(1).sum(dim=1) + + (agent[:, t + 1, :-1, :] * troll[:, t + 1, 1:, :]) + .flatten(1) + .sum(dim=1) + + (agent[:, t + 1, :, 1:] * troll[:, t + 1, :, :-1]) + .flatten(1) + .sum(dim=1) + + (agent[:, t + 1, :, :-1] * troll[:, t + 1, :, 1:]) + .flatten(1) + .sum(dim=1) + ) + hit = (hit > 0).long() + + # assert hit.min() == 0 and hit.max() <= 1 + + got_coin = (agent[:, t + 1] * coins[:, t]).flatten(1).sum(dim=1) + coins[:, t + 1] = coins[:, t] * (1 - agent[:, t + 1]) + + rewards[:, t + 1] = -hit + (1 - hit) * got_coin + + states = states + 2 * agent + 3 * troll + 4 * coins * (1 - troll) + + return states, agent_actions, rewards + + ###################################################################### + + def episodes2seq(self, states, actions, rewards): + neg = rewards.new_zeros(rewards.size()) + pos = rewards.new_zeros(rewards.size()) + for t in range(neg.size(1) - 1): + neg[:, t] = rewards[:, t:].min(dim=-1).values + pos[:, t] = rewards[:, t:].max(dim=-1).values + s = (neg < 0).long() * neg + (neg >= 0).long() * pos + + return torch.cat( + [ + self.state2code(states.flatten(2)), + self.reward2code(rewards[:, :, None]), + self.lookahead_reward2code(s[:, :, None]), + self.action2code(actions[:, :, None]), + ], + dim=2, + ).flatten(1) + + def seq2episodes(self, seq): + seq = seq.reshape(seq.size(0), -1, self.height * self.width + 3) + lookahead_rewards = self.code2lookahead_reward(seq[:, :, 0]) + states = self.code2state(seq[:, :, 1 : self.height * self.width + 1]) + states = states.reshape(states.size(0), states.size(1), self.height, self.width) + actions = self.code2action(seq[:, :, self.height * self.width + 1]) + rewards = self.code2reward(seq[:, :, self.height * self.width + 2]) + return lookahead_rewards, states, actions, rewards + + def seq2str(self, seq): + def token2str(t): + if ( + t >= self.first_states_code + and t < self.first_states_code + self.nb_states_codes + ): + return "_#@T$"[t - self.first_states_code] + elif ( + t >= self.first_actions_code + and t < self.first_actions_code + self.nb_actions_codes + ): + return "ISNEW"[t - self.first_actions_code] + elif ( + t >= self.first_rewards_code + and t < self.first_rewards_code + self.nb_rewards_codes + ): + return "-0+"[t - self.first_rewards_code] + elif ( + t >= self.first_lookahead_rewards_code + and t + < self.first_lookahead_rewards_code + self.nb_lookahead_rewards_codes + ): + return "n.pU"[t - self.first_lookahead_rewards_code] + else: + return "?" + + return ["".join([token2str(x.item()) for x in row]) for row in seq] + + ###################################################################### + + def episodes2str( + self, + lookahead_rewards, + states, + actions, + rewards, + unicode=False, + ansi_colors=False, + ): + if unicode: + symbols = "·█@T$" + # vert, hori, cross, thin_hori = "║", "═", "╬", "─" + vert, hori, cross, thin_vert, thin_hori = "┃", "━", "╋", "│", "─" else: - return "?" - - return ["".join([token2str(x.item()) for x in row]) for row in seq] - - -###################################################################### - - -def episodes2str( - lookahead_rewards, states, actions, rewards, unicode=False, ansi_colors=False -): - if unicode: - symbols = "·█@T$" - # vert, hori, cross, thin_hori = "║", "═", "╬", "─" - vert, hori, cross, thin_vert, thin_hori = "┃", "━", "╋", "│", "─" - else: - symbols = " #@T$" - vert, hori, cross, thin_vert, thin_hori = "|", "-", "+", "|", "-" - - hline = (cross + hori * states.size(-1)) * states.size(1) + cross + "\n" - - result = hline - - for n in range(states.size(0)): + symbols = " #@T$" + vert, hori, cross, thin_vert, thin_hori = "|", "-", "+", "|", "-" + + hline = (cross + hori * states.size(-1)) * states.size(1) + cross + "\n" + + result = hline + + for n in range(states.size(0)): + + def state_symbol(v): + v = v.item() + return "?" if v < 0 or v >= len(symbols) else symbols[v] + + for i in range(states.size(2)): + result += ( + vert + + vert.join( + [ + "".join([state_symbol(v) for v in row]) + for row in states[n, :, i] + ] + ) + + vert + + "\n" + ) - def state_symbol(v): - v = v.item() - return "?" if v < 0 or v >= len(symbols) else symbols[v] + # result += (vert + thin_hori * states.size(-1)) * states.size(1) + vert + "\n" + + def status_bar(a, r, lr=None): + a, r = a.item(), r.item() + sb_a = "ISNEW"[a] if a >= 0 and a < 5 else "?" + sb_r = "- +"[r + 1] if r in {-1, 0, 1} else "?" + if lr is None: + sb_lr = "" + else: + lr = lr.item() + sb_lr = "n pU"[lr + 1] if lr in {-1, 0, 1, 2} else "?" + return ( + sb_a + + "/" + + sb_r + + " " * (states.size(-1) - 1 - len(sb_a + sb_r + sb_lr)) + + sb_lr + ) - for i in range(states.size(2)): result += ( vert + vert.join( - ["".join([state_symbol(v) for v in row]) for row in states[n, :, i]] + [ + status_bar(a, r, lr) + for a, r, lr in zip( + actions[n], rewards[n], lookahead_rewards[n] + ) + ] ) + vert + "\n" ) - # result += (vert + thin_hori * states.size(-1)) * states.size(1) + vert + "\n" + result += hline - def status_bar(a, r, lr=None): - a, r = a.item(), r.item() - sb_a = "ISNEW"[a] if a >= 0 and a < 5 else "?" - sb_r = "- +"[r + 1] if r in {-1, 0, 1} else "?" - if lr is None: - sb_lr = "" - else: - lr = lr.item() - sb_lr = "n pU"[lr + 1] if lr in {-1, 0, 1, 2} else "?" - return ( - sb_a - + "/" - + sb_r - + " " * (states.size(-1) - 1 - len(sb_a + sb_r + sb_lr)) - + sb_lr - ) - - result += ( - vert - + vert.join( - [ - status_bar(a, r, lr) - for a, r, lr in zip(actions[n], rewards[n], lookahead_rewards[n]) - ] - ) - + vert - + "\n" - ) - - result += hline + if ansi_colors: + for u, c in [("T", 31), ("@", 32), ("$", 34)]: + result = result.replace(u, f"\u001b[{c}m{u}\u001b[0m") - if ansi_colors: - for u, c in [("T", 31), ("@", 32), ("$", 34)]: - result = result.replace(u, f"\u001b[{c}m{u}\u001b[0m") + return result - return result + ###################################################################### + def save_seq_as_anim_script(self, seq, filename): + it_len = self.height * self.width + 3 -###################################################################### - - -def save_seq_as_anim_script(seq, filename): - it_len = height * width + 3 - - seq = ( - seq.reshape(seq.size(0), -1, it_len) - .permute(1, 0, 2) - .reshape(T, seq.size(0), -1) - ) + seq = ( + seq.reshape(seq.size(0), -1, it_len) + .permute(1, 0, 2) + .reshape(self.T, seq.size(0), -1) + ) - with open(filename, "w") as f: - for t in range(T): - f.write("clear\n") - f.write("cat << EOF\n") - # for i in range(seq.size(2)): - # lr, s, a, r = seq2episodes(seq[t : t + 1, :, i], height, width) - lr, s, a, r = seq2episodes( - seq[t : t + 1, :].reshape(5, 10 * it_len), height, width - ) - f.write(episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)) - f.write("EOF\n") - f.write("sleep 0.25\n") - print(f"Saved {filename}") + with open(filename, "w") as f: + for t in range(self.T): + # f.write("clear\n") + f.write("cat << EOF\n") + f.write("\u001b[H") + # for i in range(seq.size(2)): + # lr, s, a, r = seq2episodes(seq[t : t + 1, :, i], self.height, self.width) + lr, s, a, r = self.seq2episodes(seq[t : t + 1, :].reshape(8, -1)) + f.write(self.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)) + f.write("EOF\n") + f.write("sleep 0.25\n") + print(f"Saved {filename}") if __name__ == "__main__": - nb, height, width, T, nb_walls = 6, 5, 7, 10, 5 - states, actions, rewards = generate_episodes(nb, height, width, T, nb_walls) - seq = episodes2seq(states, actions, rewards) - lr, s, a, r = seq2episodes(seq, height, width) - print(episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)) - - # print() - # for s in seq2str(seq): - # print(s) - - nb, T = 50, 100 - states, actions, rewards = generate_episodes( - nb=nb, height=height, width=width, T=T, nb_walls=3 - ) - seq = episodes2seq(states, actions, rewards) - save_seq_as_anim_script(seq, "anim.sh") + gw = GreedWorld(height=5, width=7, T=10, nb_walls=4, nb_coins=2) + states, actions, rewards = gw.generate_episodes(nb=6) + seq = gw.episodes2seq(states, actions, rewards) + lr, s, a, r = gw.seq2episodes(seq) + print(gw.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)) + + print() + for s in gw.seq2str(seq): + print(s) + + gw = GreedWorld(height=5, width=7, T=100, nb_walls=4, nb_coins=2) + states, actions, rewards = gw.generate_episodes(nb=128) + seq = gw.episodes2seq(states, actions, rewards) + gw.save_seq_as_anim_script(seq, "anim.sh") diff --git a/main.py b/main.py index 2339dcf..0f2cb61 100755 --- a/main.py +++ b/main.py @@ -186,6 +186,8 @@ parser.add_argument("--greed_T", type=int, default=25) parser.add_argument("--greed_nb_walls", type=int, default=5) +parser.add_argument("--greed_nb_coins", type=int, default=2) + ###################################################################### args = parser.parse_args() @@ -625,6 +627,7 @@ elif args.task == "greed": width=args.greed_width, T=args.greed_T, nb_walls=args.greed_nb_walls, + nb_coins=args.greed_nb_coins, logger=log_string, device=device, ) @@ -700,8 +703,6 @@ if args.task == "expr" and args.expr_input_file is not None: ###################################################################### -nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default - # Compute the entropy of the training tokens token_count = 0 @@ -770,7 +771,7 @@ log_string(f"learning_rate_schedule {learning_rate_schedule}") nb_samples_seen = 0 -if nb_epochs_finished >= nb_epochs: +if nb_epochs_finished >= args.nb_epochs: task.produce_results( n_epoch=nb_epochs_finished, model=model, @@ -781,7 +782,7 @@ if nb_epochs_finished >= nb_epochs: time_pred_result = None -for n_epoch in range(nb_epochs_finished, nb_epochs): +for n_epoch in range(nb_epochs_finished, args.nb_epochs): learning_rate = learning_rate_schedule[n_epoch] log_string(f"learning_rate {learning_rate}") diff --git a/mygpt.py b/mygpt.py index 77c29ce..131c822 100755 --- a/mygpt.py +++ b/mygpt.py @@ -264,6 +264,7 @@ class MyGPT(nn.Module): m.weight.fill_(1.0) def forward(self, bs): + # print(f"GENERATE {bs.first} {bs.first+bs.nb}") bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb) bs = self.embedding(bs) bs = self.trunk(bs) diff --git a/tasks.py b/tasks.py index aa5df72..6a7e639 100755 --- a/tasks.py +++ b/tasks.py @@ -1880,6 +1880,7 @@ class Greed(Task): width, T, nb_walls, + nb_coins, logger=None, device=torch.device("cpu"), ): @@ -1887,32 +1888,24 @@ class Greed(Task): self.batch_size = batch_size self.device = device - self.height = height - self.width = width - states, actions, rewards = greed.generate_episodes( - nb_train_samples + nb_test_samples, height, width, T, nb_walls + self.world = greed.GreedWorld(height, width, T, nb_walls, nb_coins) + + states, actions, rewards = self.world.generate_episodes( + nb_train_samples + nb_test_samples ) - seq = greed.episodes2seq(states, actions, rewards) - # seq = seq[:, seq.size(1) // 3 : 2 * seq.size(1) // 3] + seq = self.world.episodes2seq(states, actions, rewards) self.train_input = seq[:nb_train_samples].to(self.device) self.test_input = seq[nb_train_samples:].to(self.device) - self.state_len = self.height * self.width - self.index_lookahead_reward = 0 - self.index_states = 1 - self.index_action = self.state_len + 1 - self.index_reward = self.state_len + 2 - self.it_len = self.state_len + 3 # lookahead_reward / state / action / reward - def wipe_lookahead_rewards(self, batch): t = torch.arange(batch.size(1), device=batch.device)[None, :] u = torch.randint(batch.size(1), (batch.size(0), 1), device=batch.device) lr_mask = (t <= u).long() * ( - t % self.it_len == self.index_lookahead_reward + t % self.world.it_len == self.world.index_lookahead_reward ).long() - return lr_mask * greed.lookahead_reward2code(2) + (1 - lr_mask) * batch + return lr_mask * self.world.lookahead_reward2code(2) + (1 - lr_mask) * batch def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} @@ -1927,7 +1920,7 @@ class Greed(Task): yield self.wipe_lookahead_rewards(batch) def vocabulary_size(self): - return greed.nb_codes + return self.world.nb_codes def thinking_autoregression( self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000 @@ -1954,14 +1947,16 @@ class Greed(Task): result = self.test_input[:250].clone() # Erase all the content but that of the first iteration - result[:, self.it_len :] = -1 + result[:, self.world.it_len :] = -1 # Set the lookahead_reward of the firs to UNKNOWN - result[:, self.index_lookahead_reward] = greed.lookahead_reward2code(2) + result[:, self.world.index_lookahead_reward] = self.world.lookahead_reward2code( + 2 + ) t = torch.arange(result.size(1), device=result.device)[None, :] for u in tqdm.tqdm( - range(0, result.size(1), self.it_len), + range(0, result.size(1), self.world.it_len), desc="thinking", ): # Generate the next state but keep the initial one, the @@ -1969,31 +1964,35 @@ class Greed(Task): # UNKNOWN if u > 0: result[ - :, u + self.index_lookahead_reward - ] = greed.lookahead_reward2code(2) - ar_mask = (t >= u + self.index_states).long() * ( - t < u + self.index_states + self.state_len + :, u + self.world.index_lookahead_reward + ] = self.world.lookahead_reward2code(2) + ar_mask = (t >= u + self.world.index_states).long() * ( + t < u + self.world.index_states + self.world.state_len ).long() ar(result, ar_mask) # Generate the action and reward with lookahead_reward to +1 - result[:, u + self.index_lookahead_reward] = greed.lookahead_reward2code(1) - ar_mask = (t >= u + self.index_action).long() * ( - t <= u + self.index_reward + result[ + :, u + self.world.index_lookahead_reward + ] = self.world.lookahead_reward2code(1) + ar_mask = (t >= u + self.world.index_reward).long() * ( + t <= u + self.world.index_action ).long() ar(result, ar_mask) # Set the lookahead_reward to UNKNOWN for the next iterations - result[:, u + self.index_lookahead_reward] = greed.lookahead_reward2code(2) + result[ + :, u + self.world.index_lookahead_reward + ] = self.world.lookahead_reward2code(2) filename = os.path.join(result_dir, f"test_thinking_compute_{n_epoch:04d}.txt") with open(filename, "w") as f: for n in range(10): for s in snapshots: - lr, s, a, r = greed.seq2episodes( - s[n : n + 1], self.height, self.width + lr, s, a, r = self.world.seq2episodes( + s[n : n + 1], ) - str = greed.episodes2str( + str = self.world.episodes2str( lr, s, a, r, unicode=True, ansi_colors=True ) f.write(str) @@ -2001,8 +2000,8 @@ class Greed(Task): # Saving the generated sequences - lr, s, a, r = greed.seq2episodes(result, self.height, self.width) - str = greed.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True) + lr, s, a, r = self.world.seq2episodes(result) + str = self.world.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True) filename = os.path.join(result_dir, f"test_thinking_seq_{n_epoch:04d}.txt") with open(filename, "w") as f: @@ -2016,12 +2015,10 @@ class Greed(Task): # Saving the ground truth - lr, s, a, r = greed.seq2episodes( + lr, s, a, r = self.world.seq2episodes( result, - self.height, - self.width, ) - str = greed.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True) + str = self.world.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True) filename = os.path.join(result_dir, f"test_true_seq_{n_epoch:04d}.txt") with open(filename, "w") as f: @@ -2031,8 +2028,7 @@ class Greed(Task): # Re-generating from the first frame ar_mask = ( - torch.arange(result.size(1), device=result.device) - >= self.height * self.width + 3 + torch.arange(result.size(1), device=result.device) >= self.world.it_len ).long()[None, :] ar_mask = ar_mask.expand_as(result) result *= 1 - ar_mask # paraaaaanoiaaaaaaa @@ -2048,12 +2044,10 @@ class Greed(Task): # Saving the generated sequences - lr, s, a, r = greed.seq2episodes( + lr, s, a, r = self.world.seq2episodes( result, - self.height, - self.width, ) - str = greed.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True) + str = self.world.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True) filename = os.path.join(result_dir, f"test_seq_{n_epoch:04d}.txt") with open(filename, "w") as f: -- 2.39.5