X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=greed.py;h=dc11d146772153e79b3cc9e537de54652621ecc2;hb=888106500badae460e9ae2183512c7124601acad;hp=6b271b55b414aefcc6a10fa3a674a9d1b62ec3bc;hpb=5cb7fa9e00bdcf59a4a50bd7deefec416e87fe43;p=culture.git diff --git a/greed.py b/greed.py index 6b271b5..dc11d14 100755 --- a/greed.py +++ b/greed.py @@ -77,6 +77,8 @@ def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3, nb_coins=2): rnd = rnd * (1 - wall.clamp(max=1)) rnd = torch.rand(nb, height, width) + rnd[:, 0, 0] = 0 # Do not put coin at the agent's starting + # position coins = torch.zeros(nb, T, height, width, dtype=torch.int64) rnd = rnd * (1 - wall.clamp(max=1)) for k in range(nb_coins): @@ -279,12 +281,45 @@ def episodes2str( ###################################################################### + +def save_seq_as_anim_script(seq, filename): + it_len = height * width + 3 + + seq = ( + seq.reshape(seq.size(0), -1, it_len) + .permute(1, 0, 2) + .reshape(T, seq.size(0), -1) + ) + + with open(filename, "w") as f: + for t in range(T): + f.write("clear\n") + f.write("cat << EOF\n") + # for i in range(seq.size(2)): + # lr, s, a, r = seq2episodes(seq[t : t + 1, :, i], height, width) + lr, s, a, r = seq2episodes( + seq[t : t + 1, :].reshape(5, 10 * it_len), height, width + ) + f.write(episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)) + f.write("EOF\n") + f.write("sleep 0.25\n") + print(f"Saved {filename}") + + if __name__ == "__main__": - nb, height, width, T, nb_walls = 5, 5, 7, 10, 5 + nb, height, width, T, nb_walls = 6, 5, 7, 10, 5 states, actions, rewards = generate_episodes(nb, height, width, T, nb_walls) seq = episodes2seq(states, actions, rewards) lr, s, a, r = seq2episodes(seq, height, width) print(episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)) + # print() # for s in seq2str(seq): # print(s) + + nb, T = 50, 100 + states, actions, rewards = generate_episodes( + nb=nb, height=height, width=width, T=T, nb_walls=3 + ) + seq = episodes2seq(states, actions, rewards) + save_seq_as_anim_script(seq, "anim.sh")