X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=escape.py;h=43843f0bfbea776410b42cccaf56c534a673e150;hb=62ad2378c60cdf322c0111279bd45fbef8365fc2;hp=1c1bc20470d969d93247cdad66b949f002455410;hpb=e39282eef52a7f5ab6654b999009127569b1b599;p=picoclvr.git diff --git a/escape.py b/escape.py index 1c1bc20..43843f0 100755 --- a/escape.py +++ b/escape.py @@ -25,7 +25,7 @@ nb_codes = first_lookahead_rewards_code + nb_lookahead_rewards_codes ###################################################################### -def generate_episodes(nb, height=6, width=6, T=10): +def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3): rnd = torch.rand(nb, height, width) rnd[:, 0, :] = 0 rnd[:, -1, :] = 0 @@ -33,11 +33,12 @@ def generate_episodes(nb, height=6, width=6, T=10): rnd[:, :, -1] = 0 wall = 0 - for k in range(3): + for k in range(nb_walls): wall = wall + ( rnd.flatten(1).argmax(dim=1)[:, None] == torch.arange(rnd.flatten(1).size(1))[None, :] ).long().reshape(rnd.size()) + rnd = rnd * (1 - wall.clamp(max=1)) states = wall[:, None, :, :].expand(-1, T, -1, -1).clone() @@ -280,8 +281,8 @@ def episodes2str( ###################################################################### if __name__ == "__main__": - nb, height, width, T = 25, 5, 7, 25 - states, actions, rewards = generate_episodes(nb, height, width, T) + nb, height, width, T, nb_walls = 25, 5, 7, 25, 5 + states, actions, rewards = generate_episodes(nb, height, width, T, nb_walls) seq = episodes2seq(states, actions, rewards, lookahead_delta=T) s, a, r, lr = seq2episodes(seq, height, width, lookahead=True) print(episodes2str(s, a, r, lookahead_rewards=lr, unicode=True, ansi_colors=True))