From: François Fleuret Date: Tue, 26 Mar 2024 15:52:58 +0000 (+0100) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=picoclvr.git;a=commitdiff_plain;h=aa21f7edd3969ca509dfc1378fb5d1a1f7ebf9d1 Update. --- diff --git a/greed.py b/greed.py index 6b271b5..3cbe886 100755 --- a/greed.py +++ b/greed.py @@ -279,12 +279,44 @@ def episodes2str( ###################################################################### + +def save_seq_as_anim_script(seq, filename): + it_len = height * width + 3 + + seq = ( + seq.reshape(seq.size(0), -1, it_len) + .permute(1, 0, 2) + .reshape(T, seq.size(0), -1) + ) + + with open(filename, "w") as f: + for t in range(T): + f.write("clear\n") + f.write("cat << EOF\n") + # for i in range(seq.size(2)): + # lr, s, a, r = seq2episodes(seq[t : t + 1, :, i], height, width) + lr, s, a, r = seq2episodes( + seq[t : t + 1, :].reshape(5, 10 * it_len), height, width + ) + f.write(episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)) + f.write("EOF\n") + f.write("sleep 0.5\n") + + if __name__ == "__main__": - nb, height, width, T, nb_walls = 5, 5, 7, 10, 5 + nb, height, width, T, nb_walls = 6, 5, 7, 10, 5 states, actions, rewards = generate_episodes(nb, height, width, T, nb_walls) seq = episodes2seq(states, actions, rewards) lr, s, a, r = seq2episodes(seq, height, width) print(episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)) + # print() # for s in seq2str(seq): # print(s) + + nb, T = 50, 100 + states, actions, rewards = generate_episodes( + nb=nb, height=height, width=width, T=T, nb_walls=3 + ) + seq = episodes2seq(states, actions, rewards) + save_seq_as_anim_script(seq, "anim.sh")