Update.
author    François Fleuret <francois@fleuret.org>
Wed, 27 Mar 2024 20:17:55 +0000 (21:17 +0100)
committer François Fleuret <francois@fleuret.org>
Wed, 27 Mar 2024 20:17:55 +0000 (21:17 +0100)
greed.py
main.py
mygpt.py
tasks.py

diff --git a/greed.py b/greed.py
index dc11d14..47cfb40 100755 (executable)
--- a/greed.py
+++ b/greed.py
@@ -11,315 +11,339 @@ from torch.nn import functional as F
 
 ######################################################################
 
-nb_states_codes = 5
-nb_actions_codes = 5
-nb_rewards_codes = 3
-nb_lookahead_rewards_codes = 4  # stands for -1, 0, +1, and UNKNOWN
 
-first_states_code = 0
-first_actions_code = first_states_code + nb_states_codes
-first_rewards_code = first_actions_code + nb_actions_codes
-first_lookahead_rewards_code = first_rewards_code + nb_rewards_codes
-nb_codes = first_lookahead_rewards_code + nb_lookahead_rewards_codes
-
-######################################################################
-
-
-def state2code(r):
-    return r + first_states_code
-
-
-def code2state(r):
-    return r - first_states_code
-
-
-def action2code(r):
-    return r + first_actions_code
-
-
-def code2action(r):
-    return r - first_actions_code
+class GreedWorld:
+    def __init__(self, height=6, width=6, T=10, nb_walls=3, nb_coins=2):
+        self.height = height
+        self.width = width
+        self.T = T
+        self.nb_walls = nb_walls
+        self.nb_coins = nb_coins
+
+        self.nb_states_codes = 5
+        self.nb_actions_codes = 5
+        self.nb_rewards_codes = 3
+        self.nb_lookahead_rewards_codes = 4  # stands for -1, 0, +1, and UNKNOWN
+
+        self.first_states_code = 0
+        self.first_actions_code = self.first_states_code + self.nb_states_codes
+        self.first_rewards_code = self.first_actions_code + self.nb_actions_codes
+        self.first_lookahead_rewards_code = (
+            self.first_rewards_code + self.nb_rewards_codes
+        )
+        self.nb_codes = (
+            self.first_lookahead_rewards_code + self.nb_lookahead_rewards_codes
+        )
 
+        self.state_len = self.height * self.width
+        self.index_states = 0
+        self.index_reward = self.state_len
+        self.index_lookahead_reward = self.state_len + 1
+        self.index_action = self.state_len + 2
+        self.it_len = self.state_len + 3  # state / reward / lookahead_reward / action
 
-def reward2code(r):
-    return r + 1 + first_rewards_code
+    def state2code(self, r):
+        return r + self.first_states_code
 
+    def code2state(self, r):
+        return r - self.first_states_code
 
-def code2reward(r):
-    return r - first_rewards_code - 1
+    def action2code(self, r):
+        return r + self.first_actions_code
 
+    def code2action(self, r):
+        return r - self.first_actions_code
 
-def lookahead_reward2code(r):
-    # -1, 0, +1 or 2 for UNKNOWN
-    return r + 1 + first_lookahead_rewards_code
+    def reward2code(self, r):
+        return r + 1 + self.first_rewards_code
 
+    def code2reward(self, r):
+        return r - self.first_rewards_code - 1
 
-def code2lookahead_reward(r):
-    return r - first_lookahead_rewards_code - 1
+    def lookahead_reward2code(self, r):
+        # -1, 0, +1 or 2 for UNKNOWN
+        return r + 1 + self.first_lookahead_rewards_code
 
+    def code2lookahead_reward(self, r):
+        return r - self.first_lookahead_rewards_code - 1
 
-######################################################################
+    ######################################################################
 
+    def generate_episodes(self, nb):
+        rnd = torch.rand(nb, self.height, self.width)
+        rnd[:, 0, :] = 0
+        rnd[:, -1, :] = 0
+        rnd[:, :, 0] = 0
+        rnd[:, :, -1] = 0
+        wall = 0
+        for k in range(self.nb_walls):
+            wall = wall + (
+                rnd.flatten(1).argmax(dim=1)[:, None]
+                == torch.arange(rnd.flatten(1).size(1))[None, :]
+            ).long().reshape(rnd.size())
 
-def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3, nb_coins=2):
-    rnd = torch.rand(nb, height, width)
-    rnd[:, 0, :] = 0
-    rnd[:, -1, :] = 0
-    rnd[:, :, 0] = 0
-    rnd[:, :, -1] = 0
-    wall = 0
-    for k in range(nb_walls):
-        wall = wall + (
-            rnd.flatten(1).argmax(dim=1)[:, None]
-            == torch.arange(rnd.flatten(1).size(1))[None, :]
-        ).long().reshape(rnd.size())
+            rnd = rnd * (1 - wall.clamp(max=1))
 
+        rnd = torch.rand(nb, self.height, self.width)
+        rnd[:, 0, 0] = 0  # Do not put coin at the agent's starting
+        # position
+        coins = torch.zeros(nb, self.T, self.height, self.width, dtype=torch.int64)
         rnd = rnd * (1 - wall.clamp(max=1))
-
-    rnd = torch.rand(nb, height, width)
-    rnd[:, 0, 0] = 0  # Do not put coin at the agent's starting
-    # position
-    coins = torch.zeros(nb, T, height, width, dtype=torch.int64)
-    rnd = rnd * (1 - wall.clamp(max=1))
-    for k in range(nb_coins):
-        coins[:, 0] = coins[:, 0] + (
-            rnd.flatten(1).argmax(dim=1)[:, None]
-            == torch.arange(rnd.flatten(1).size(1))[None, :]
-        ).long().reshape(rnd.size())
-
-        rnd = rnd * (1 - coins[:, 0].clamp(max=1))
-
-    states = wall[:, None, :, :].expand(-1, T, -1, -1).clone()
-
-    agent = torch.zeros(states.size(), dtype=torch.int64)
-    agent[:, 0, 0, 0] = 1
-    agent_actions = torch.randint(5, (nb, T))
-    rewards = torch.zeros(nb, T, dtype=torch.int64)
-
-    troll = torch.zeros(states.size(), dtype=torch.int64)
-    troll[:, 0, -1, -1] = 1
-    troll_actions = torch.randint(5, (nb, T))
-
-    all_moves = agent.new(nb, 5, height, width)
-    for t in range(T - 1):
-        all_moves.zero_()
-        all_moves[:, 0] = agent[:, t]
-        all_moves[:, 1, 1:, :] = agent[:, t, :-1, :]
-        all_moves[:, 2, :-1, :] = agent[:, t, 1:, :]
-        all_moves[:, 3, :, 1:] = agent[:, t, :, :-1]
-        all_moves[:, 4, :, :-1] = agent[:, t, :, 1:]
-        a = F.one_hot(agent_actions[:, t], num_classes=5)[:, :, None, None]
-        after_move = (all_moves * a).sum(dim=1)
-        collision = (
-            (after_move * (1 - wall) * (1 - troll[:, t]))
-            .flatten(1)
-            .sum(dim=1)[:, None, None]
-            == 0
-        ).long()
-        agent[:, t + 1] = collision * agent[:, t] + (1 - collision) * after_move
-
-        all_moves.zero_()
-        all_moves[:, 0] = troll[:, t]
-        all_moves[:, 1, 1:, :] = troll[:, t, :-1, :]
-        all_moves[:, 2, :-1, :] = troll[:, t, 1:, :]
-        all_moves[:, 3, :, 1:] = troll[:, t, :, :-1]
-        all_moves[:, 4, :, :-1] = troll[:, t, :, 1:]
-        a = F.one_hot(troll_actions[:, t], num_classes=5)[:, :, None, None]
-        after_move = (all_moves * a).sum(dim=1)
-        collision = (
-            (after_move * (1 - wall) * (1 - agent[:, t + 1]))
-            .flatten(1)
-            .sum(dim=1)[:, None, None]
-            == 0
-        ).long()
-        troll[:, t + 1] = collision * troll[:, t] + (1 - collision) * after_move
-
-        hit = (
-            (agent[:, t + 1, 1:, :] * troll[:, t + 1, :-1, :]).flatten(1).sum(dim=1)
-            + (agent[:, t + 1, :-1, :] * troll[:, t + 1, 1:, :]).flatten(1).sum(dim=1)
-            + (agent[:, t + 1, :, 1:] * troll[:, t + 1, :, :-1]).flatten(1).sum(dim=1)
-            + (agent[:, t + 1, :, :-1] * troll[:, t + 1, :, 1:]).flatten(1).sum(dim=1)
-        )
-        hit = (hit > 0).long()
-
-        # assert hit.min() == 0 and hit.max() <= 1
-
-        got_coin = (agent[:, t + 1] * coins[:, t]).flatten(1).sum(dim=1)
-        coins[:, t + 1] = coins[:, t] * (1 - agent[:, t + 1])
-
-        rewards[:, t + 1] = -hit + (1 - hit) * got_coin
-
-    states = states + 2 * agent + 3 * troll + 4 * coins * (1 - troll)
-
-    return states, agent_actions, rewards
-
-
-######################################################################
-
-
-def episodes2seq(states, actions, rewards):
-    neg = rewards.new_zeros(rewards.size())
-    pos = rewards.new_zeros(rewards.size())
-    for t in range(neg.size(1) - 1):
-        neg[:, t] = rewards[:, t:].min(dim=-1).values
-        pos[:, t] = rewards[:, t:].max(dim=-1).values
-    s = (neg < 0).long() * neg + (neg >= 0).long() * pos
-
-    return torch.cat(
-        [
-            lookahead_reward2code(s[:, :, None]),
-            state2code(states.flatten(2)),
-            action2code(actions[:, :, None]),
-            reward2code(rewards[:, :, None]),
-        ],
-        dim=2,
-    ).flatten(1)
-
-
-def seq2episodes(seq, height, width):
-    seq = seq.reshape(seq.size(0), -1, height * width + 3)
-    lookahead_rewards = code2lookahead_reward(seq[:, :, 0])
-    states = code2state(seq[:, :, 1 : height * width + 1])
-    states = states.reshape(states.size(0), states.size(1), height, width)
-    actions = code2action(seq[:, :, height * width + 1])
-    rewards = code2reward(seq[:, :, height * width + 2])
-    return lookahead_rewards, states, actions, rewards
-
-
-def seq2str(seq):
-    def token2str(t):
-        if t >= first_states_code and t < first_states_code + nb_states_codes:
-            return " #@T$"[t - first_states_code]
-        elif t >= first_actions_code and t < first_actions_code + nb_actions_codes:
-            return "ISNEW"[t - first_actions_code]
-        elif t >= first_rewards_code and t < first_rewards_code + nb_rewards_codes:
-            return "-0+"[t - first_rewards_code]
-        elif (
-            t >= first_lookahead_rewards_code
-            and t < first_lookahead_rewards_code + nb_lookahead_rewards_codes
-        ):
-            return "n.pU"[t - first_lookahead_rewards_code]
+        for k in range(self.nb_coins):
+            coins[:, 0] = coins[:, 0] + (
+                rnd.flatten(1).argmax(dim=1)[:, None]
+                == torch.arange(rnd.flatten(1).size(1))[None, :]
+            ).long().reshape(rnd.size())
+
+            rnd = rnd * (1 - coins[:, 0].clamp(max=1))
+
+        states = wall[:, None, :, :].expand(-1, self.T, -1, -1).clone()
+
+        agent = torch.zeros(states.size(), dtype=torch.int64)
+        agent[:, 0, 0, 0] = 1
+        agent_actions = torch.randint(5, (nb, self.T))
+        rewards = torch.zeros(nb, self.T, dtype=torch.int64)
+
+        troll = torch.zeros(states.size(), dtype=torch.int64)
+        troll[:, 0, -1, -1] = 1
+        troll_actions = torch.randint(5, (nb, self.T))
+
+        all_moves = agent.new(nb, 5, self.height, self.width)
+        for t in range(self.T - 1):
+            all_moves.zero_()
+            all_moves[:, 0] = agent[:, t]
+            all_moves[:, 1, 1:, :] = agent[:, t, :-1, :]
+            all_moves[:, 2, :-1, :] = agent[:, t, 1:, :]
+            all_moves[:, 3, :, 1:] = agent[:, t, :, :-1]
+            all_moves[:, 4, :, :-1] = agent[:, t, :, 1:]
+            a = F.one_hot(agent_actions[:, t], num_classes=5)[:, :, None, None]
+            after_move = (all_moves * a).sum(dim=1)
+            collision = (
+                (after_move * (1 - wall) * (1 - troll[:, t]))
+                .flatten(1)
+                .sum(dim=1)[:, None, None]
+                == 0
+            ).long()
+            agent[:, t + 1] = collision * agent[:, t] + (1 - collision) * after_move
+
+            all_moves.zero_()
+            all_moves[:, 0] = troll[:, t]
+            all_moves[:, 1, 1:, :] = troll[:, t, :-1, :]
+            all_moves[:, 2, :-1, :] = troll[:, t, 1:, :]
+            all_moves[:, 3, :, 1:] = troll[:, t, :, :-1]
+            all_moves[:, 4, :, :-1] = troll[:, t, :, 1:]
+            a = F.one_hot(troll_actions[:, t], num_classes=5)[:, :, None, None]
+            after_move = (all_moves * a).sum(dim=1)
+            collision = (
+                (after_move * (1 - wall) * (1 - agent[:, t + 1]))
+                .flatten(1)
+                .sum(dim=1)[:, None, None]
+                == 0
+            ).long()
+            troll[:, t + 1] = collision * troll[:, t] + (1 - collision) * after_move
+
+            hit = (
+                (agent[:, t + 1, 1:, :] * troll[:, t + 1, :-1, :]).flatten(1).sum(dim=1)
+                + (agent[:, t + 1, :-1, :] * troll[:, t + 1, 1:, :])
+                .flatten(1)
+                .sum(dim=1)
+                + (agent[:, t + 1, :, 1:] * troll[:, t + 1, :, :-1])
+                .flatten(1)
+                .sum(dim=1)
+                + (agent[:, t + 1, :, :-1] * troll[:, t + 1, :, 1:])
+                .flatten(1)
+                .sum(dim=1)
+            )
+            hit = (hit > 0).long()
+
+            # assert hit.min() == 0 and hit.max() <= 1
+
+            got_coin = (agent[:, t + 1] * coins[:, t]).flatten(1).sum(dim=1)
+            coins[:, t + 1] = coins[:, t] * (1 - agent[:, t + 1])
+
+            rewards[:, t + 1] = -hit + (1 - hit) * got_coin
+
+        states = states + 2 * agent + 3 * troll + 4 * coins * (1 - troll)
+
+        return states, agent_actions, rewards
+
+    ######################################################################
+
+    def episodes2seq(self, states, actions, rewards):
+        neg = rewards.new_zeros(rewards.size())
+        pos = rewards.new_zeros(rewards.size())
+        for t in range(neg.size(1) - 1):
+            neg[:, t] = rewards[:, t:].min(dim=-1).values
+            pos[:, t] = rewards[:, t:].max(dim=-1).values
+        s = (neg < 0).long() * neg + (neg >= 0).long() * pos
+
+        return torch.cat(
+            [
+                self.state2code(states.flatten(2)),
+                self.reward2code(rewards[:, :, None]),
+                self.lookahead_reward2code(s[:, :, None]),
+                self.action2code(actions[:, :, None]),
+            ],
+            dim=2,
+        ).flatten(1)
+
+    def seq2episodes(self, seq):
+        # decode in the same order as episodes2seq:
+        # state / reward / lookahead_reward / action
+        seq = seq.reshape(seq.size(0), -1, self.it_len)
+        states = self.code2state(seq[:, :, : self.state_len])
+        states = states.reshape(states.size(0), states.size(1), self.height, self.width)
+        rewards = self.code2reward(seq[:, :, self.index_reward])
+        lookahead_rewards = self.code2lookahead_reward(
+            seq[:, :, self.index_lookahead_reward]
+        )
+        actions = self.code2action(seq[:, :, self.index_action])
+        return lookahead_rewards, states, actions, rewards
+
+    def seq2str(self, seq):
+        def token2str(t):
+            if (
+                t >= self.first_states_code
+                and t < self.first_states_code + self.nb_states_codes
+            ):
+                return "_#@T$"[t - self.first_states_code]
+            elif (
+                t >= self.first_actions_code
+                and t < self.first_actions_code + self.nb_actions_codes
+            ):
+                return "ISNEW"[t - self.first_actions_code]
+            elif (
+                t >= self.first_rewards_code
+                and t < self.first_rewards_code + self.nb_rewards_codes
+            ):
+                return "-0+"[t - self.first_rewards_code]
+            elif (
+                t >= self.first_lookahead_rewards_code
+                and t
+                < self.first_lookahead_rewards_code + self.nb_lookahead_rewards_codes
+            ):
+                return "n.pU"[t - self.first_lookahead_rewards_code]
+            else:
+                return "?"
+
+        return ["".join([token2str(x.item()) for x in row]) for row in seq]
+
+    ######################################################################
+
+    def episodes2str(
+        self,
+        lookahead_rewards,
+        states,
+        actions,
+        rewards,
+        unicode=False,
+        ansi_colors=False,
+    ):
+        if unicode:
+            symbols = "·█@T$"
+            # vert, hori, cross, thin_hori = "║", "═", "╬", "─"
+            vert, hori, cross, thin_vert, thin_hori = "┃", "━", "╋", "│", "─"
         else:
-            return "?"
-
-    return ["".join([token2str(x.item()) for x in row]) for row in seq]
-
-
-######################################################################
-
-
-def episodes2str(
-    lookahead_rewards, states, actions, rewards, unicode=False, ansi_colors=False
-):
-    if unicode:
-        symbols = "·█@T$"
-        # vert, hori, cross, thin_hori = "║", "═", "╬", "─"
-        vert, hori, cross, thin_vert, thin_hori = "┃", "━", "╋", "│", "─"
-    else:
-        symbols = " #@T$"
-        vert, hori, cross, thin_vert, thin_hori = "|", "-", "+", "|", "-"
-
-    hline = (cross + hori * states.size(-1)) * states.size(1) + cross + "\n"
-
-    result = hline
-
-    for n in range(states.size(0)):
+            symbols = " #@T$"
+            vert, hori, cross, thin_vert, thin_hori = "|", "-", "+", "|", "-"
+
+        hline = (cross + hori * states.size(-1)) * states.size(1) + cross + "\n"
+
+        result = hline
+
+        for n in range(states.size(0)):
+
+            def state_symbol(v):
+                v = v.item()
+                return "?" if v < 0 or v >= len(symbols) else symbols[v]
+
+            for i in range(states.size(2)):
+                result += (
+                    vert
+                    + vert.join(
+                        [
+                            "".join([state_symbol(v) for v in row])
+                            for row in states[n, :, i]
+                        ]
+                    )
+                    + vert
+                    + "\n"
+                )
 
-        def state_symbol(v):
-            v = v.item()
-            return "?" if v < 0 or v >= len(symbols) else symbols[v]
+            # result += (vert + thin_hori * states.size(-1)) * states.size(1) + vert + "\n"
+
+            def status_bar(a, r, lr=None):
+                a, r = a.item(), r.item()
+                sb_a = "ISNEW"[a] if a >= 0 and a < 5 else "?"
+                sb_r = "- +"[r + 1] if r in {-1, 0, 1} else "?"
+                if lr is None:
+                    sb_lr = ""
+                else:
+                    lr = lr.item()
+                    sb_lr = "n pU"[lr + 1] if lr in {-1, 0, 1, 2} else "?"
+                return (
+                    sb_a
+                    + "/"
+                    + sb_r
+                    + " " * (states.size(-1) - 1 - len(sb_a + sb_r + sb_lr))
+                    + sb_lr
+                )
 
-        for i in range(states.size(2)):
             result += (
                 vert
                 + vert.join(
-                    ["".join([state_symbol(v) for v in row]) for row in states[n, :, i]]
+                    [
+                        status_bar(a, r, lr)
+                        for a, r, lr in zip(
+                            actions[n], rewards[n], lookahead_rewards[n]
+                        )
+                    ]
                 )
                 + vert
                 + "\n"
             )
 
-        # result += (vert + thin_hori * states.size(-1)) * states.size(1) + vert + "\n"
+            result += hline
 
-        def status_bar(a, r, lr=None):
-            a, r = a.item(), r.item()
-            sb_a = "ISNEW"[a] if a >= 0 and a < 5 else "?"
-            sb_r = "- +"[r + 1] if r in {-1, 0, 1} else "?"
-            if lr is None:
-                sb_lr = ""
-            else:
-                lr = lr.item()
-                sb_lr = "n pU"[lr + 1] if lr in {-1, 0, 1, 2} else "?"
-            return (
-                sb_a
-                + "/"
-                + sb_r
-                + " " * (states.size(-1) - 1 - len(sb_a + sb_r + sb_lr))
-                + sb_lr
-            )
-
-        result += (
-            vert
-            + vert.join(
-                [
-                    status_bar(a, r, lr)
-                    for a, r, lr in zip(actions[n], rewards[n], lookahead_rewards[n])
-                ]
-            )
-            + vert
-            + "\n"
-        )
-
-        result += hline
+        if ansi_colors:
+            for u, c in [("T", 31), ("@", 32), ("$", 34)]:
+                result = result.replace(u, f"\u001b[{c}m{u}\u001b[0m")
 
-    if ansi_colors:
-        for u, c in [("T", 31), ("@", 32), ("$", 34)]:
-            result = result.replace(u, f"\u001b[{c}m{u}\u001b[0m")
+        return result
 
-    return result
+    ######################################################################
 
+    def save_seq_as_anim_script(self, seq, filename):
+        it_len = self.height * self.width + 3
 
-######################################################################
-
-
-def save_seq_as_anim_script(seq, filename):
-    it_len = height * width + 3
-
-    seq = (
-        seq.reshape(seq.size(0), -1, it_len)
-        .permute(1, 0, 2)
-        .reshape(T, seq.size(0), -1)
-    )
+        seq = (
+            seq.reshape(seq.size(0), -1, it_len)
+            .permute(1, 0, 2)
+            .reshape(self.T, seq.size(0), -1)
+        )
 
-    with open(filename, "w") as f:
-        for t in range(T):
-            f.write("clear\n")
-            f.write("cat << EOF\n")
-            # for i in range(seq.size(2)):
-            # lr, s, a, r = seq2episodes(seq[t : t + 1, :, i], height, width)
-            lr, s, a, r = seq2episodes(
-                seq[t : t + 1, :].reshape(5, 10 * it_len), height, width
-            )
-            f.write(episodes2str(lr, s, a, r, unicode=True, ansi_colors=True))
-            f.write("EOF\n")
-            f.write("sleep 0.25\n")
-        print(f"Saved {filename}")
+        with open(filename, "w") as f:
+            for t in range(self.T):
+                # f.write("clear\n")
+                f.write("cat << EOF\n")
+                f.write("\u001b[H")
+                # for i in range(seq.size(2)):
+                # lr, s, a, r = seq2episodes(seq[t : t + 1, :, i], self.height, self.width)
+                lr, s, a, r = self.seq2episodes(seq[t : t + 1, :].reshape(8, -1))
+                f.write(self.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True))
+                f.write("EOF\n")
+                f.write("sleep 0.25\n")
+            print(f"Saved {filename}")
 
 
 if __name__ == "__main__":
-    nb, height, width, T, nb_walls = 6, 5, 7, 10, 5
-    states, actions, rewards = generate_episodes(nb, height, width, T, nb_walls)
-    seq = episodes2seq(states, actions, rewards)
-    lr, s, a, r = seq2episodes(seq, height, width)
-    print(episodes2str(lr, s, a, r, unicode=True, ansi_colors=True))
-
-    # print()
-    # for s in seq2str(seq):
-    # print(s)
-
-    nb, T = 50, 100
-    states, actions, rewards = generate_episodes(
-        nb=nb, height=height, width=width, T=T, nb_walls=3
-    )
-    seq = episodes2seq(states, actions, rewards)
-    save_seq_as_anim_script(seq, "anim.sh")
+    gw = GreedWorld(height=5, width=7, T=10, nb_walls=4, nb_coins=2)
+    states, actions, rewards = gw.generate_episodes(nb=6)
+    seq = gw.episodes2seq(states, actions, rewards)
+    lr, s, a, r = gw.seq2episodes(seq)
+    print(gw.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True))
+
+    print()
+    for s in gw.seq2str(seq):
+        print(s)
+
+    gw = GreedWorld(height=5, width=7, T=100, nb_walls=4, nb_coins=2)
+    states, actions, rewards = gw.generate_episodes(nb=128)
+    seq = gw.episodes2seq(states, actions, rewards)
+    gw.save_seq_as_anim_script(seq, "anim.sh")
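
For reference, a minimal sketch of what the new GreedWorld encoding implies, assuming the constructor defaults above (gw is just an illustrative name, not part of the commit):

    from greed import GreedWorld

    gw = GreedWorld()  # defaults: 6x6 grid, T=10, 3 walls, 2 coins
    # vocabulary = 5 state + 5 action + 3 reward + 4 lookahead-reward codes
    assert gw.nb_codes == 17
    # one iteration of the flattened sequence is laid out as
    #   [ 36 state tokens | reward | lookahead_reward | action ]
    assert gw.it_len == gw.height * gw.width + 3 == 39
    states, actions, rewards = gw.generate_episodes(nb=4)
    seq = gw.episodes2seq(states, actions, rewards)
    assert seq.size() == (4, gw.T * gw.it_len)
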
diff --git a/main.py b/main.py
index 2339dcf..0f2cb61 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -186,6 +186,8 @@ parser.add_argument("--greed_T", type=int, default=25)
 
 parser.add_argument("--greed_nb_walls", type=int, default=5)
 
+parser.add_argument("--greed_nb_coins", type=int, default=2)
+
 ######################################################################
 
 args = parser.parse_args()
@@ -625,6 +627,7 @@ elif args.task == "greed":
         width=args.greed_width,
         T=args.greed_T,
         nb_walls=args.greed_nb_walls,
+        nb_coins=args.greed_nb_coins,
         logger=log_string,
         device=device,
     )
@@ -700,8 +703,6 @@ if args.task == "expr" and args.expr_input_file is not None:
 
 ######################################################################
 
-nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
-
 # Compute the entropy of the training tokens
 
 token_count = 0
@@ -770,7 +771,7 @@ log_string(f"learning_rate_schedule {learning_rate_schedule}")
 
 nb_samples_seen = 0
 
-if nb_epochs_finished >= nb_epochs:
+if nb_epochs_finished >= args.nb_epochs:
     task.produce_results(
         n_epoch=nb_epochs_finished,
         model=model,
@@ -781,7 +782,7 @@ if nb_epochs_finished >= nb_epochs:
 
 time_pred_result = None
 
-for n_epoch in range(nb_epochs_finished, nb_epochs):
+for n_epoch in range(nb_epochs_finished, args.nb_epochs):
     learning_rate = learning_rate_schedule[n_epoch]
 
     log_string(f"learning_rate {learning_rate}")
diff --git a/mygpt.py b/mygpt.py
index 77c29ce..131c822 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -264,6 +264,7 @@ class MyGPT(nn.Module):
                     m.weight.fill_(1.0)
 
     def forward(self, bs):
+        # print(f"GENERATE {bs.first} {bs.first+bs.nb}")
         bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
         bs = self.embedding(bs)
         bs = self.trunk(bs)
diff --git a/tasks.py b/tasks.py
index aa5df72..6a7e639 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -1880,6 +1880,7 @@ class Greed(Task):
         width,
         T,
         nb_walls,
+        nb_coins,
         logger=None,
         device=torch.device("cpu"),
     ):
@@ -1887,32 +1888,24 @@ class Greed(Task):
 
         self.batch_size = batch_size
         self.device = device
-        self.height = height
-        self.width = width
 
-        states, actions, rewards = greed.generate_episodes(
-            nb_train_samples + nb_test_samples, height, width, T, nb_walls
+        self.world = greed.GreedWorld(height, width, T, nb_walls, nb_coins)
+
+        states, actions, rewards = self.world.generate_episodes(
+            nb_train_samples + nb_test_samples
         )
-        seq = greed.episodes2seq(states, actions, rewards)
-        # seq = seq[:, seq.size(1) // 3 : 2 * seq.size(1) // 3]
+        seq = self.world.episodes2seq(states, actions, rewards)
         self.train_input = seq[:nb_train_samples].to(self.device)
         self.test_input = seq[nb_train_samples:].to(self.device)
 
-        self.state_len = self.height * self.width
-        self.index_lookahead_reward = 0
-        self.index_states = 1
-        self.index_action = self.state_len + 1
-        self.index_reward = self.state_len + 2
-        self.it_len = self.state_len + 3  # lookahead_reward / state / action / reward
-
     def wipe_lookahead_rewards(self, batch):
         t = torch.arange(batch.size(1), device=batch.device)[None, :]
         u = torch.randint(batch.size(1), (batch.size(0), 1), device=batch.device)
         lr_mask = (t <= u).long() * (
-            t % self.it_len == self.index_lookahead_reward
+            t % self.world.it_len == self.world.index_lookahead_reward
         ).long()
 
-        return lr_mask * greed.lookahead_reward2code(2) + (1 - lr_mask) * batch
+        return lr_mask * self.world.lookahead_reward2code(2) + (1 - lr_mask) * batch
 
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -1927,7 +1920,7 @@ class Greed(Task):
             yield self.wipe_lookahead_rewards(batch)
 
     def vocabulary_size(self):
-        return greed.nb_codes
+        return self.world.nb_codes
 
     def thinking_autoregression(
         self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
@@ -1954,14 +1947,16 @@ class Greed(Task):
 
         result = self.test_input[:250].clone()
         # Erase all the content but that of the first iteration
-        result[:, self.it_len :] = -1
+        result[:, self.world.it_len :] = -1
         # Set the lookahead_reward of the first iteration to UNKNOWN
-        result[:, self.index_lookahead_reward] = greed.lookahead_reward2code(2)
+        result[:, self.world.index_lookahead_reward] = self.world.lookahead_reward2code(
+            2
+        )
 
         t = torch.arange(result.size(1), device=result.device)[None, :]
 
         for u in tqdm.tqdm(
-            range(0, result.size(1), self.it_len),
+            range(0, result.size(1), self.world.it_len),
             desc="thinking",
         ):
             # Generate the next state but keep the initial one, the
@@ -1969,31 +1964,35 @@ class Greed(Task):
             # UNKNOWN
             if u > 0:
                 result[
-                    :, u + self.index_lookahead_reward
-                ] = greed.lookahead_reward2code(2)
-                ar_mask = (t >= u + self.index_states).long() * (
-                    t < u + self.index_states + self.state_len
+                    :, u + self.world.index_lookahead_reward
+                ] = self.world.lookahead_reward2code(2)
+                ar_mask = (t >= u + self.world.index_states).long() * (
+                    t < u + self.world.index_states + self.world.state_len
                 ).long()
                 ar(result, ar_mask)
 
             # Generate the action and reward with lookahead_reward to +1
-            result[:, u + self.index_lookahead_reward] = greed.lookahead_reward2code(1)
-            ar_mask = (t >= u + self.index_action).long() * (
-                t <= u + self.index_reward
+            result[
+                :, u + self.world.index_lookahead_reward
+            ] = self.world.lookahead_reward2code(1)
+            ar_mask = (t >= u + self.world.index_reward).long() * (
+                t <= u + self.world.index_action
             ).long()
             ar(result, ar_mask)
 
             # Set the lookahead_reward to UNKNOWN for the next iterations
-            result[:, u + self.index_lookahead_reward] = greed.lookahead_reward2code(2)
+            result[
+                :, u + self.world.index_lookahead_reward
+            ] = self.world.lookahead_reward2code(2)
 
         filename = os.path.join(result_dir, f"test_thinking_compute_{n_epoch:04d}.txt")
         with open(filename, "w") as f:
             for n in range(10):
                 for s in snapshots:
-                    lr, s, a, r = greed.seq2episodes(
-                        s[n : n + 1], self.height, self.width
+                    lr, s, a, r = self.world.seq2episodes(
+                        s[n : n + 1],
                     )
-                    str = greed.episodes2str(
+                    str = self.world.episodes2str(
                         lr, s, a, r, unicode=True, ansi_colors=True
                     )
                     f.write(str)
@@ -2001,8 +2000,8 @@ class Greed(Task):
 
         # Saving the generated sequences
 
-        lr, s, a, r = greed.seq2episodes(result, self.height, self.width)
-        str = greed.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
+        lr, s, a, r = self.world.seq2episodes(result)
+        str = self.world.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
 
         filename = os.path.join(result_dir, f"test_thinking_seq_{n_epoch:04d}.txt")
         with open(filename, "w") as f:
@@ -2016,12 +2015,10 @@ class Greed(Task):
 
         # Saving the ground truth
 
-        lr, s, a, r = greed.seq2episodes(
+        lr, s, a, r = self.world.seq2episodes(
             result,
-            self.height,
-            self.width,
         )
-        str = greed.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
+        str = self.world.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
 
         filename = os.path.join(result_dir, f"test_true_seq_{n_epoch:04d}.txt")
         with open(filename, "w") as f:
@@ -2031,8 +2028,7 @@ class Greed(Task):
         # Re-generating from the first frame
 
         ar_mask = (
-            torch.arange(result.size(1), device=result.device)
-            >= self.height * self.width + 3
+            torch.arange(result.size(1), device=result.device) >= self.world.it_len
         ).long()[None, :]
         ar_mask = ar_mask.expand_as(result)
         result *= 1 - ar_mask  # paraaaaanoiaaaaaaa
@@ -2048,12 +2044,10 @@ class Greed(Task):
 
         # Saving the generated sequences
 
-        lr, s, a, r = greed.seq2episodes(
+        lr, s, a, r = self.world.seq2episodes(
             result,
-            self.height,
-            self.width,
         )
-        str = greed.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
+        str = self.world.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
 
         filename = os.path.join(result_dir, f"test_seq_{n_epoch:04d}.txt")
         with open(filename, "w") as f: