Update.
[picoclvr.git] / greed.py
index f7b4cf7..20cef79 100755 (executable)
--- a/greed.py
+++ b/greed.py
@@ -94,9 +94,9 @@ def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3, nb_coins=2):
     agent_actions = torch.randint(5, (nb, T))
     rewards = torch.zeros(nb, T, dtype=torch.int64)
 
-    monster = torch.zeros(states.size(), dtype=torch.int64)
-    monster[:, 0, -1, -1] = 1
-    monster_actions = torch.randint(5, (nb, T))
+    troll = torch.zeros(states.size(), dtype=torch.int64)
+    troll[:, 0, -1, -1] = 1
+    troll_actions = torch.randint(5, (nb, T))
 
     all_moves = agent.new(nb, 5, height, width)
     for t in range(T - 1):
@@ -109,7 +109,7 @@ def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3, nb_coins=2):
         a = F.one_hot(agent_actions[:, t], num_classes=5)[:, :, None, None]
         after_move = (all_moves * a).sum(dim=1)
         collision = (
-            (after_move * (1 - wall) * (1 - monster[:, t]))
+            (after_move * (1 - wall) * (1 - troll[:, t]))
             .flatten(1)
             .sum(dim=1)[:, None, None]
             == 0
@@ -117,12 +117,12 @@ def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3, nb_coins=2):
         agent[:, t + 1] = collision * agent[:, t] + (1 - collision) * after_move
 
         all_moves.zero_()
-        all_moves[:, 0] = monster[:, t]
-        all_moves[:, 1, 1:, :] = monster[:, t, :-1, :]
-        all_moves[:, 2, :-1, :] = monster[:, t, 1:, :]
-        all_moves[:, 3, :, 1:] = monster[:, t, :, :-1]
-        all_moves[:, 4, :, :-1] = monster[:, t, :, 1:]
-        a = F.one_hot(monster_actions[:, t], num_classes=5)[:, :, None, None]
+        all_moves[:, 0] = troll[:, t]
+        all_moves[:, 1, 1:, :] = troll[:, t, :-1, :]
+        all_moves[:, 2, :-1, :] = troll[:, t, 1:, :]
+        all_moves[:, 3, :, 1:] = troll[:, t, :, :-1]
+        all_moves[:, 4, :, :-1] = troll[:, t, :, 1:]
+        a = F.one_hot(troll_actions[:, t], num_classes=5)[:, :, None, None]
         after_move = (all_moves * a).sum(dim=1)
         collision = (
             (after_move * (1 - wall) * (1 - agent[:, t + 1]))
@@ -130,13 +130,13 @@ def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3, nb_coins=2):
             .sum(dim=1)[:, None, None]
             == 0
         ).long()
-        monster[:, t + 1] = collision * monster[:, t] + (1 - collision) * after_move
+        troll[:, t + 1] = collision * troll[:, t] + (1 - collision) * after_move
 
         hit = (
-            (agent[:, t + 1, 1:, :] * monster[:, t + 1, :-1, :]).flatten(1).sum(dim=1)
-            + (agent[:, t + 1, :-1, :] * monster[:, t + 1, 1:, :]).flatten(1).sum(dim=1)
-            + (agent[:, t + 1, :, 1:] * monster[:, t + 1, :, :-1]).flatten(1).sum(dim=1)
-            + (agent[:, t + 1, :, :-1] * monster[:, t + 1, :, 1:]).flatten(1).sum(dim=1)
+            (agent[:, t + 1, 1:, :] * troll[:, t + 1, :-1, :]).flatten(1).sum(dim=1)
+            + (agent[:, t + 1, :-1, :] * troll[:, t + 1, 1:, :]).flatten(1).sum(dim=1)
+            + (agent[:, t + 1, :, 1:] * troll[:, t + 1, :, :-1]).flatten(1).sum(dim=1)
+            + (agent[:, t + 1, :, :-1] * troll[:, t + 1, :, 1:]).flatten(1).sum(dim=1)
         )
         hit = (hit > 0).long()
 
@@ -147,7 +147,7 @@ def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3, nb_coins=2):
 
         rewards[:, t + 1] = -hit + (1 - hit) * got_coin
 
-    states = states + 2 * agent + 3 * monster + 4 * coins
+    states = states + 2 * agent + 3 * troll + 4 * coins * (1 - troll)
 
     return states, agent_actions, rewards
 
@@ -271,7 +271,7 @@ def episodes2str(
         result += hline
 
     if ansi_colors:
-        for u, c in [("$", 31), ("@", 32)]:
+        for u, c in [("T", 31), ("@", 32), ("$", 34)]:
             result = result.replace(u, f"\u001b[{c}m{u}\u001b[0m")
 
     return result
@@ -279,12 +279,45 @@ def episodes2str(
 
 ######################################################################
 
+
+def save_seq_as_anim_script(seq, filename):
+    it_len = height * width + 3
+
+    seq = (
+        seq.reshape(seq.size(0), -1, it_len)
+        .permute(1, 0, 2)
+        .reshape(T, seq.size(0), -1)
+    )
+
+    with open(filename, "w") as f:
+        for t in range(T):
+            f.write("clear\n")
+            f.write("cat << EOF\n")
+            # for i in range(seq.size(2)):
+            # lr, s, a, r = seq2episodes(seq[t : t + 1, :, i], height, width)
+            lr, s, a, r = seq2episodes(
+                seq[t : t + 1, :].reshape(5, 10 * it_len), height, width
+            )
+            f.write(episodes2str(lr, s, a, r, unicode=True, ansi_colors=True))
+            f.write("EOF\n")
+            f.write("sleep 0.25\n")
+        print(f"Saved {filename}")
+
+
 if __name__ == "__main__":
-    nb, height, width, T, nb_walls = 5, 5, 7, 10, 5
+    nb, height, width, T, nb_walls = 6, 5, 7, 10, 5
     states, actions, rewards = generate_episodes(nb, height, width, T, nb_walls)
     seq = episodes2seq(states, actions, rewards)
     lr, s, a, r = seq2episodes(seq, height, width)
     print(episodes2str(lr, s, a, r, unicode=True, ansi_colors=True))
+
     # print()
     # for s in seq2str(seq):
     # print(s)
+
+    nb, T = 50, 100
+    states, actions, rewards = generate_episodes(
+        nb=nb, height=height, width=width, T=T, nb_walls=3
+    )
+    seq = episodes2seq(states, actions, rewards)
+    save_seq_as_anim_script(seq, "anim.sh")