X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=stack.py;h=543f04e5628b7276769241922c3770afd5d416f9;hb=08b58304225e044a21419dd30302d985acc1824c;hp=675182eb335f6fa467ffedbfe9bd70b821b1c30f;hpb=87da428a5ab9ac3cd49ab22bd27e572d0b16f29c;p=picoclvr.git diff --git a/stack.py b/stack.py index 675182e..543f04e 100755 --- a/stack.py +++ b/stack.py @@ -13,7 +13,9 @@ import torch, torchvision # CODE_VAL=val + 2 * nb_stacks -def generate_sequences(nb, nb_steps, nb_stacks, nb_digits, device=torch.device("cpu")): +def generate_sequences( + nb, nb_steps, nb_stacks, nb_digits, values=None, device=torch.device("cpu") +): stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64) stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64) k = torch.arange(nb) @@ -26,14 +28,17 @@ def generate_sequences(nb, nb_steps, nb_stacks, nb_digits, device=torch.device(" op = torch.randint(2, (nb,)) st = torch.randint(nb_stacks, (nb,)) op = op * (stack_counts[k, st] > 0) - val_push = torch.randint(10**nb_digits, (nb,)) + if values is None: + val_push = torch.randint(10**nb_digits, (nb,)) + else: + val_push = values[torch.randint(values.size(0), (nb,))] val_pop = stack[ k, st, (stack_counts[k, st] - 1).clamp(min=0), ] stack[k, st, stack_counts[k, st]] = val_push - recorded_stack_counts[:, (1 + nb_digits) * t + 1] = stack_counts[k, st] + recorded_stack_counts[:, (1 + nb_digits) * t] = stack_counts[k, st] stack_counts[k[op == 0], st[op == 0]] += 1 stack_counts[k[op == 1], st[op == 1]] -= 1 result[:, (1 + nb_digits) * t] = st * 2 + op @@ -59,6 +64,8 @@ def seq_to_str(seq, nb_stacks, nb_digits, recorded_stack_counts=None): n_op = seq[(1 + nb_digits) * t] if t > 0: s += " " + if recorded_stack_counts is not None: + s += f"[{recorded_stack_counts[(1 + nb_digits)*t]}] " s += f"POP" if n_op % 2 == 1 else f"PSH" if nb_stacks > 1: s += f"_{n_op//2}" @@ -67,15 +74,13 @@ def seq_to_str(seq, nb_stacks, nb_digits, recorded_stack_counts=None): s += " ?" else: s += f" {seq[(1 + nb_digits) * t + 1 + d] - 2 * nb_stacks:1d}" - if recorded_stack_counts is not None: - s += f"[{recorded_stack_counts[(1 + nb_digits)*t+1]}] " return s ###################################################################### if __name__ == "__main__": - nb, nb_steps, nb_stacks, nb_digits = 150000, 10, 1, 1 + nb, nb_steps, nb_stacks, nb_digits = 150000, 20, 2, 1 seq, recorded_stack_counts = generate_sequences( nb=nb, nb_steps=nb_steps, @@ -83,13 +88,18 @@ if __name__ == "__main__": nb_digits=nb_digits, ) - print("-- TRAIN -----------------------------") - for n in range(min(10, seq.size(0))): - # print(seq_to_str(seq[n], recorded_stack_counts[n])) - print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits)) - - print("-- TEST ------------------------------") + print( + seq_to_str( + seq[n], + nb_stacks=nb_stacks, + nb_digits=nb_digits, + recorded_stack_counts=recorded_stack_counts[n], + ) + ) + # print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits)) + + print("-- PREPARED FOR TEST -----------------") remove_popped_values(seq, nb_stacks, nb_digits)