Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 26 Mar 2024 12:00:24 +0000 (13:00 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 26 Mar 2024 12:00:24 +0000 (13:00 +0100)
greed.py

index f7b4cf7..6b271b5 100755 (executable)
--- a/greed.py
+++ b/greed.py
@@ -94,9 +94,9 @@ def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3, nb_coins=2):
     agent_actions = torch.randint(5, (nb, T))
     rewards = torch.zeros(nb, T, dtype=torch.int64)
 
-    monster = torch.zeros(states.size(), dtype=torch.int64)
-    monster[:, 0, -1, -1] = 1
-    monster_actions = torch.randint(5, (nb, T))
+    troll = torch.zeros(states.size(), dtype=torch.int64)
+    troll[:, 0, -1, -1] = 1
+    troll_actions = torch.randint(5, (nb, T))
 
     all_moves = agent.new(nb, 5, height, width)
     for t in range(T - 1):
@@ -109,7 +109,7 @@ def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3, nb_coins=2):
         a = F.one_hot(agent_actions[:, t], num_classes=5)[:, :, None, None]
         after_move = (all_moves * a).sum(dim=1)
         collision = (
-            (after_move * (1 - wall) * (1 - monster[:, t]))
+            (after_move * (1 - wall) * (1 - troll[:, t]))
             .flatten(1)
             .sum(dim=1)[:, None, None]
             == 0
@@ -117,12 +117,12 @@ def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3, nb_coins=2):
         agent[:, t + 1] = collision * agent[:, t] + (1 - collision) * after_move
 
         all_moves.zero_()
-        all_moves[:, 0] = monster[:, t]
-        all_moves[:, 1, 1:, :] = monster[:, t, :-1, :]
-        all_moves[:, 2, :-1, :] = monster[:, t, 1:, :]
-        all_moves[:, 3, :, 1:] = monster[:, t, :, :-1]
-        all_moves[:, 4, :, :-1] = monster[:, t, :, 1:]
-        a = F.one_hot(monster_actions[:, t], num_classes=5)[:, :, None, None]
+        all_moves[:, 0] = troll[:, t]
+        all_moves[:, 1, 1:, :] = troll[:, t, :-1, :]
+        all_moves[:, 2, :-1, :] = troll[:, t, 1:, :]
+        all_moves[:, 3, :, 1:] = troll[:, t, :, :-1]
+        all_moves[:, 4, :, :-1] = troll[:, t, :, 1:]
+        a = F.one_hot(troll_actions[:, t], num_classes=5)[:, :, None, None]
         after_move = (all_moves * a).sum(dim=1)
         collision = (
             (after_move * (1 - wall) * (1 - agent[:, t + 1]))
@@ -130,13 +130,13 @@ def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3, nb_coins=2):
             .sum(dim=1)[:, None, None]
             == 0
         ).long()
-        monster[:, t + 1] = collision * monster[:, t] + (1 - collision) * after_move
+        troll[:, t + 1] = collision * troll[:, t] + (1 - collision) * after_move
 
         hit = (
-            (agent[:, t + 1, 1:, :] * monster[:, t + 1, :-1, :]).flatten(1).sum(dim=1)
-            + (agent[:, t + 1, :-1, :] * monster[:, t + 1, 1:, :]).flatten(1).sum(dim=1)
-            + (agent[:, t + 1, :, 1:] * monster[:, t + 1, :, :-1]).flatten(1).sum(dim=1)
-            + (agent[:, t + 1, :, :-1] * monster[:, t + 1, :, 1:]).flatten(1).sum(dim=1)
+            (agent[:, t + 1, 1:, :] * troll[:, t + 1, :-1, :]).flatten(1).sum(dim=1)
+            + (agent[:, t + 1, :-1, :] * troll[:, t + 1, 1:, :]).flatten(1).sum(dim=1)
+            + (agent[:, t + 1, :, 1:] * troll[:, t + 1, :, :-1]).flatten(1).sum(dim=1)
+            + (agent[:, t + 1, :, :-1] * troll[:, t + 1, :, 1:]).flatten(1).sum(dim=1)
         )
         hit = (hit > 0).long()
 
@@ -147,7 +147,7 @@ def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3, nb_coins=2):
 
         rewards[:, t + 1] = -hit + (1 - hit) * got_coin
 
-    states = states + 2 * agent + 3 * monster + 4 * coins
+    states = states + 2 * agent + 3 * troll + 4 * coins * (1 - troll)
 
     return states, agent_actions, rewards
 
@@ -271,7 +271,7 @@ def episodes2str(
         result += hline
 
     if ansi_colors:
-        for u, c in [("$", 31), ("@", 32)]:
+        for u, c in [("T", 31), ("@", 32), ("$", 34)]:
             result = result.replace(u, f"\u001b[{c}m{u}\u001b[0m")
 
     return result