X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=stack.py;h=3844161efc4460a7f396646bdc9362d6de3a1a47;hb=7d596f5ca37e7e9ac53047acdc7d00928821929f;hp=675182eb335f6fa467ffedbfe9bd70b821b1c30f;hpb=87da428a5ab9ac3cd49ab22bd27e572d0b16f29c;p=picoclvr.git diff --git a/stack.py b/stack.py index 675182e..3844161 100755 --- a/stack.py +++ b/stack.py @@ -13,7 +13,9 @@ import torch, torchvision # CODE_VAL=val + 2 * nb_stacks -def generate_sequences(nb, nb_steps, nb_stacks, nb_digits, device=torch.device("cpu")): +def generate_sequences( + nb, nb_steps, nb_stacks, nb_digits, values=None, device=torch.device("cpu") +): stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64) stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64) k = torch.arange(nb) @@ -26,7 +28,10 @@ def generate_sequences(nb, nb_steps, nb_stacks, nb_digits, device=torch.device(" op = torch.randint(2, (nb,)) st = torch.randint(nb_stacks, (nb,)) op = op * (stack_counts[k, st] > 0) - val_push = torch.randint(10**nb_digits, (nb,)) + if values is None: + val_push = torch.randint(10**nb_digits, (nb,)) + else: + val_push = values[torch.randint(values.size(0), (nb,))] val_pop = stack[ k, st, @@ -75,7 +80,7 @@ def seq_to_str(seq, nb_stacks, nb_digits, recorded_stack_counts=None): ###################################################################### if __name__ == "__main__": - nb, nb_steps, nb_stacks, nb_digits = 150000, 10, 1, 1 + nb, nb_steps, nb_stacks, nb_digits = 150000, 20, 2, 1 seq, recorded_stack_counts = generate_sequences( nb=nb, nb_steps=nb_steps, @@ -83,13 +88,11 @@ if __name__ == "__main__": nb_digits=nb_digits, ) - print("-- TRAIN -----------------------------") - for n in range(min(10, seq.size(0))): # print(seq_to_str(seq[n], recorded_stack_counts[n])) print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits)) - print("-- TEST ------------------------------") + print("-- PREPARED FOR TEST -----------------") remove_popped_values(seq, nb_stacks, nb_digits)