X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=escape.py;h=aa87cddd728220933bcba5d137a36953ad4b08b8;hb=a85b2f79ae551da3454daa7b90554be5ca3c5bf6;hp=7fcafaa91fc8fc2b8f50732c5cf914f682cad6bf;hpb=22b841a39cc73310cd03dbd1d32fb387f68521d0;p=picoclvr.git diff --git a/escape.py b/escape.py index 7fcafaa..aa87cdd 100755 --- a/escape.py +++ b/escape.py @@ -150,14 +150,16 @@ def episodes2str(states, actions, rewards, unicode=False, ansi_colors=False): result = hline for n in range(states.size(0)): + + def state_symbol(v): + v = v.item() + return "?" if v < 0 or v >= len(symbols) else symbols[v] + for i in range(states.size(2)): result += ( vert + vert.join( - [ - "".join([symbols[v.item()] for v in row]) - for row in states[n, :, i] - ] + ["".join([state_symbol(v) for v in row]) for row in states[n, :, i]] ) + vert + "\n" @@ -166,8 +168,9 @@ def episodes2str(states, actions, rewards, unicode=False, ansi_colors=False): result += (vert + thin_hori * states.size(-1)) * states.size(1) + vert + "\n" def status_bar(a, r): - a = "ISNEW"[a.item()] - r = "" if r == 0 else f"{r.item()}" + a = a.item() + a = "ISNEW"[a] if a >= 0 and a < 5 else "?" + r = "?" if r < -1 or r > 2 else ("" if r == 0 else f"{r.item()}") return a + " " * (states.size(-1) - len(a) - len(r)) + r result += (