X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=stack.py;fp=stack.py;h=219a1ad013692e49429d85866950b1afe7e5cbf6;hb=4502a109727b0424ff6d4df90f17b361524f9e73;hp=675182eb335f6fa467ffedbfe9bd70b821b1c30f;hpb=87da428a5ab9ac3cd49ab22bd27e572d0b16f29c;p=picoclvr.git diff --git a/stack.py b/stack.py index 675182e..219a1ad 100755 --- a/stack.py +++ b/stack.py @@ -13,7 +13,9 @@ import torch, torchvision # CODE_VAL=val + 2 * nb_stacks -def generate_sequences(nb, nb_steps, nb_stacks, nb_digits, device=torch.device("cpu")): +def generate_sequences( + nb, nb_steps, nb_stacks, nb_digits, values=None, device=torch.device("cpu") +): stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64) stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64) k = torch.arange(nb) @@ -26,7 +28,10 @@ def generate_sequences(nb, nb_steps, nb_stacks, nb_digits, device=torch.device(" op = torch.randint(2, (nb,)) st = torch.randint(nb_stacks, (nb,)) op = op * (stack_counts[k, st] > 0) - val_push = torch.randint(10**nb_digits, (nb,)) + if values is None: + val_push = torch.randint(10**nb_digits, (nb,)) + else: + val_push = values[torch.randint(values.size(0), (nb,))] val_pop = stack[ k, st,