X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=stack.py;h=312b39fda81b142b4dd9c10bc35259f8c8c7dc21;hb=76671c582f029aa67fce2626764b02e8d9e2dbeb;hp=dc494bb88d2428b0f173b7e59fbf0e0eea873226;hpb=16bf88f88bbab138c0dc33b4fbd2d88cf9db3ae5;p=picoclvr.git diff --git a/stack.py b/stack.py index dc494bb..312b39f 100755 --- a/stack.py +++ b/stack.py @@ -13,44 +13,52 @@ import torch, torchvision # CODE_VAL=val + 2 * nb_stacks -def generate(nb, nb_steps, nb_stacks, nb_values): +def generate_sequences(nb, nb_steps, nb_stacks, nb_values, device=torch.device("cpu")): stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64) - stack_pointers = torch.zeros(nb, nb_stacks, dtype=torch.int64) + stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64) k = torch.arange(nb) result = torch.empty(nb, 2 * nb_steps, dtype=torch.int64) - depth_counts = torch.zeros(nb, 2 * nb_steps, dtype=torch.int64) + recorded_stack_counts = torch.zeros(nb, 2 * nb_steps, dtype=torch.int64) for t in range(nb_steps): op = torch.randint(2, (nb,)) st = torch.randint(nb_stacks, (nb,)) - op = op * (stack_pointers[k, st] > 0) + op = op * (stack_counts[k, st] > 0) val_push = torch.randint(nb_values, (nb,)) val_pop = stack[ k, st, - (stack_pointers[k, st] - 1).clamp(min=0), + (stack_counts[k, st] - 1).clamp(min=0), ] - stack[k, st, stack_pointers[k, st]] = val_push - depth_counts[:, 2 * t + 1] = stack_pointers[k, st] - stack_pointers[k[op == 0], st[op == 0]] += 1 - stack_pointers[k[op == 1], st[op == 1]] -= 1 + stack[k, st, stack_counts[k, st]] = val_push + recorded_stack_counts[:, 2 * t + 1] = stack_counts[k, st] + stack_counts[k[op == 0], st[op == 0]] += 1 + stack_counts[k[op == 1], st[op == 1]] -= 1 result[:, 2 * t] = st * 2 + op result[:, 2 * t + 1] = (op * val_pop + (1 - op) * val_push) + 2 * nb_stacks - return result, depth_counts + return result.to(device), recorded_stack_counts.to(device) -def seq_to_str(seq, depth_counts=None): +def remove_poped_values(seq, nb_stacks): + m = torch.logical_and(seq % 2 == 1, seq < 2 * nb_stacks).long() + seq[:, 1:] = -m[:, :-1] + (1 - m[:, :-1]) * seq[:, 1:] + + +def seq_to_str(seq, recorded_stack_counts=None): assert seq.size(0) % 2 == 0 s = "" for t in range(seq.size(0) // 2): op = seq[2 * t] - op = f"POP_{op//2}" if op % 2 == 1 else f"PUSH_{op//2}" - val = seq[2 * t + 1] - 2 * nb_stacks + op = f"POP_{op//2}" if op % 2 == 1 else f"PSH_{op//2}" + if seq[2 * t + 1] == -1: + val = "?" + else: + val = seq[2 * t + 1] - 2 * nb_stacks if t > 0: s += " " - if depth_counts is not None: - s += f"[{depth_counts[2*t+1]}] " + if recorded_stack_counts is not None: + s += f"[{recorded_stack_counts[2*t+1]}] " s += f"{op} {val}" return s @@ -59,9 +67,17 @@ def seq_to_str(seq, depth_counts=None): if __name__ == "__main__": nb, nb_steps, nb_stacks, nb_values = 150000, 10, 1, 5 - seq, depth_counts = generate( + seq, recorded_stack_counts = generate_sequences( nb=nb, nb_steps=nb_steps, nb_stacks=nb_stacks, nb_values=nb_values ) for n in range(min(10, seq.size(0))): - print(seq_to_str(seq[n], depth_counts[n])) + # print(seq_to_str(seq[n], recorded_stack_counts[n])) + print(seq_to_str(seq[n])) + + print("--------------------------------------") + + remove_poped_values(seq, nb_stacks) + + for n in range(min(10, seq.size(0))): + print(seq_to_str(seq[n]))