X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=stack.py;h=3844161efc4460a7f396646bdc9362d6de3a1a47;hb=27dea0ab0448511236cb344c17162e84359a14ee;hp=458ec794bd7a33f32d967cb946a4c17d7e2db88f;hpb=4b512400504fcac28c0ee85c804571c688696b85;p=picoclvr.git diff --git a/stack.py b/stack.py index 458ec79..3844161 100755 --- a/stack.py +++ b/stack.py @@ -80,7 +80,7 @@ def seq_to_str(seq, nb_stacks, nb_digits, recorded_stack_counts=None): ###################################################################### if __name__ == "__main__": - nb, nb_steps, nb_stacks, nb_digits = 150000, 10, 1, 1 + nb, nb_steps, nb_stacks, nb_digits = 150000, 20, 2, 1 seq, recorded_stack_counts = generate_sequences( nb=nb, nb_steps=nb_steps, @@ -88,8 +88,6 @@ if __name__ == "__main__": nb_digits=nb_digits, ) - print("-- TRAIN -----------------------------") - for n in range(min(10, seq.size(0))): # print(seq_to_str(seq[n], recorded_stack_counts[n])) print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))