Update.
[picoclvr.git] / escape.py
index a3d8c85..5f34cd1 100755 (executable)
--- a/escape.py
+++ b/escape.py
@@ -14,7 +14,7 @@ from torch.nn import functional as F
 nb_states_codes = 5
 nb_actions_codes = 5
 nb_rewards_codes = 3
-nb_lookahead_rewards_codes = 3
+nb_lookahead_rewards_codes = 4  # stands for -1, 0, +1, and UNKNOWN
 
 first_states_code = 0
 first_actions_code = first_states_code + nb_states_codes
@@ -50,6 +50,7 @@ def code2reward(r):
 
 
 def lookahead_reward2code(r):
+    # -1, 0, +1 or 2 for UNKNOWN
     return r + 1 + first_lookahead_rewards_code
 
 
@@ -60,7 +61,7 @@ def code2lookahead_reward(r):
 ######################################################################
 
 
-def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3, nb_coins=3):
+def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3, nb_coins=2):
     rnd = torch.rand(nb, height, width)
     rnd[:, 0, :] = 0
     rnd[:, -1, :] = 0
@@ -195,7 +196,7 @@ def seq2str(seq):
             t >= first_lookahead_rewards_code
             and t < first_lookahead_rewards_code + nb_lookahead_rewards_codes
         ):
-            return "n.p"[t - first_lookahead_rewards_code]
+            return "n.pU"[t - first_lookahead_rewards_code]
         else:
             return "?"
 
@@ -241,7 +242,7 @@ def episodes2str(
         def status_bar(a, r, lr=None):
             a, r = a.item(), r.item()
             sb_a = "ISNEW"[a] if a >= 0 and a < 5 else "?"
-            sb_r = "- +"[r + 1] if r in {-1, 0, 1} else "?"
+            sb_r = "- +U"[r + 1] if r in {-1, 0, 1, 2} else "?"
             if lr is None:
                 sb_lr = ""
             else: