# CODE_VAL=val + 2 * nb_stacks
-def generate(nb, seq_len, nb_stacks, nb_values):
- stack = torch.empty(nb, nb_stacks, seq_len, dtype=torch.int64)
+def generate(nb, nb_steps, nb_stacks, nb_values):
+ stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64)
stack_pointers = torch.zeros(nb, nb_stacks, dtype=torch.int64)
k = torch.arange(nb)
- result = torch.empty(nb, 2 * seq_len, dtype=torch.int64)
+ result = torch.empty(nb, 2 * nb_steps, dtype=torch.int64)
+ depth_counts = torch.zeros(nb, 2 * nb_steps, dtype=torch.int64)
- for t in range(seq_len):
+ for t in range(nb_steps):
op = torch.randint(2, (nb,))
st = torch.randint(nb_stacks, (nb,))
op = op * (stack_pointers[k, st] > 0)
val_push = torch.randint(nb_values, (nb,))
- # top_val[n,s]=stack[n,stack_pointers[n,s]]
- top_values = stack[
+ val_pop = stack[
k,
st,
(stack_pointers[k, st] - 1).clamp(min=0),
]
- stack[
- k[:, None].expand_as(stack_pointers),
- st[:, None].expand_as(stack_pointers),
- stack_pointers,
- ] = val_push[:, None].expand_as(stack_pointers)
+ stack[k, st, stack_pointers[k, st]] = val_push
+ depth_counts[:, 2 * t + 1] = stack_pointers[k, st]
stack_pointers[k[op == 0], st[op == 0]] += 1
stack_pointers[k[op == 1], st[op == 1]] -= 1
result[:, 2 * t] = st * 2 + op
- result[:, 2 * t + 1] = (op * top_values + (1 - op) * val_push) + 2 * nb_stacks
+ result[:, 2 * t + 1] = (op * val_pop + (1 - op) * val_push) + 2 * nb_stacks
- return result
+ return result, depth_counts
-def seq_to_str(seq):
+def seq_to_str(seq, depth_counts=None):
assert seq.size(0) % 2 == 0
s = ""
- for t in range(0, seq.size(0), 2):
- op = seq[t]
+ for t in range(seq.size(0) // 2):
+ op = seq[2 * t]
op = f"POP_{op//2}" if op % 2 == 1 else f"PUSH_{op//2}"
- val = seq[t + 1] - 2 * nb_stacks
+ val = seq[2 * t + 1] - 2 * nb_stacks
if t > 0:
s += " "
+ if depth_counts is not None:
+ s += f"[{depth_counts[2*t+1]}] "
s += f"{op} {val}"
return s
######################################################################
if __name__ == "__main__":
- nb, seq_len, nb_stacks, nb_values = 3, 10, 1, 5
- result = generate(nb=nb, seq_len=seq_len, nb_stacks=nb_stacks, nb_values=nb_values)
- for n in range(result.size(0)):
- print(seq_to_str(result[n]))
+ nb, nb_steps, nb_stacks, nb_values = 150000, 10, 1, 5
+ seq, depth_counts = generate(
+ nb=nb, nb_steps=nb_steps, nb_stacks=nb_stacks, nb_values=nb_values
+ )
+
+ for n in range(min(10, seq.size(0))):
+ print(seq_to_str(seq[n], depth_counts[n]))