X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=stack.py;fp=stack.py;h=ba452aab0c600230d0e60b95da18e60cfee00c07;hb=c5daf2eeedb26a25789de370171d592c621a2fac;hp=312b39fda81b142b4dd9c10bc35259f8c8c7dc21;hpb=76671c582f029aa67fce2626764b02e8d9e2dbeb;p=picoclvr.git diff --git a/stack.py b/stack.py index 312b39f..ba452aa 100755 --- a/stack.py +++ b/stack.py @@ -45,12 +45,13 @@ def remove_poped_values(seq, nb_stacks): seq[:, 1:] = -m[:, :-1] + (1 - m[:, :-1]) * seq[:, 1:] -def seq_to_str(seq, recorded_stack_counts=None): +def seq_to_str(seq, show_stack_nb=True,recorded_stack_counts=None): assert seq.size(0) % 2 == 0 s = "" for t in range(seq.size(0) // 2): - op = seq[2 * t] - op = f"POP_{op//2}" if op % 2 == 1 else f"PSH_{op//2}" + n_op = seq[2 * t] + op = f"POP" if n_op % 2 == 1 else f"PSH" + if show_stack_nb: op+=f"_{n_op//2}" if seq[2 * t + 1] == -1: val = "?" else: @@ -71,13 +72,15 @@ if __name__ == "__main__": nb=nb, nb_steps=nb_steps, nb_stacks=nb_stacks, nb_values=nb_values ) + print("-- TRAIN -----------------------------") + for n in range(min(10, seq.size(0))): # print(seq_to_str(seq[n], recorded_stack_counts[n])) - print(seq_to_str(seq[n])) + print(seq_to_str(seq[n],show_stack_nb=nb_stacks>1)) - print("--------------------------------------") + print("-- TEST ------------------------------") remove_poped_values(seq, nb_stacks) for n in range(min(10, seq.size(0))): - print(seq_to_str(seq[n])) + print(seq_to_str(seq[n],show_stack_nb=nb_stacks>1))