3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import torch, torchvision
10 ######################################################################
# Each step emits two tokens, an operation code followed by a value code:
# CODE_OP = (0 for push, 1 for pop) + 2 * stack_index
# CODE_VAL = val + 2 * nb_stacks
def generate(nb, nb_steps, nb_stacks, nb_values):
    """Sample random push/pop traces over a set of stacks.

    For each of the `nb` samples and each of the `nb_steps` steps, a stack
    index and an operation (0 = push, 1 = pop) are drawn at random. A pop
    drawn for an empty stack is turned into a push.

    Args:
        nb: number of independent sequences to generate.
        nb_steps: number of operations per sequence.
        nb_stacks: number of stacks an operation can target.
        nb_values: pushed values are drawn from `range(nb_values)`.

    Returns:
        A pair `(result, depth_counts)` of int64 tensors, both of shape
        `(nb, 2 * nb_steps)`:

        * `result[:, 2 * t]` is `2 * stack_index + op`;
        * `result[:, 2 * t + 1]` is the pushed or popped value, offset by
          `2 * nb_stacks`;
        * `depth_counts[:, 2 * t + 1]` is the depth of the targeted stack
          before the operation is applied; even columns are left at zero.
    """
    stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64)
    stack_pointers = torch.zeros(nb, nb_stacks, dtype=torch.int64)
    # Sample index, paired with `st` to address one stack per sample.
    k = torch.arange(nb)
    result = torch.empty(nb, 2 * nb_steps, dtype=torch.int64)
    depth_counts = torch.zeros(nb, 2 * nb_steps, dtype=torch.int64)

    for t in range(nb_steps):
        op = torch.randint(2, (nb,))
        st = torch.randint(nb_stacks, (nb,))

        # Advanced indexing returns a copy, so `depth` keeps the pre-update
        # depths even after `stack_pointers` is modified below.
        depth = stack_pointers[k, st]

        # Popping an empty stack is not allowed: force a push instead.
        op = op * (depth > 0)

        val_push = torch.randint(nb_values, (nb,))
        # Top of the targeted stack; the clamp keeps the index valid for
        # empty stacks, whose (meaningless) value is masked out by `op` below.
        val_pop = stack[k, st, (depth - 1).clamp(min=0)]

        # Written unconditionally: for a pop this lands above the current top
        # and is simply ignored.
        stack[k, st, depth] = val_push
        depth_counts[:, 2 * t + 1] = depth

        stack_pointers[k[op == 0], st[op == 0]] += 1
        stack_pointers[k[op == 1], st[op == 1]] -= 1

        result[:, 2 * t] = st * 2 + op
        result[:, 2 * t + 1] = (op * val_pop + (1 - op) * val_push) + 2 * nb_stacks

    return result, depth_counts
def seq_to_str(seq, depth_counts=None, nb_stacks=None):
    """Render one encoded sequence as a human-readable string.

    Args:
        seq: 1-d tensor of even length alternating operation codes
            (`2 * stack_index + op`, op 1 meaning pop) and value codes
            (`value + 2 * nb_stacks`).
        depth_counts: optional sequence aligned with `seq`; when given, the
            entry at each value position is printed as `[depth] ` in front
            of the corresponding operation.
        nb_stacks: number of stacks used when encoding `seq`. When omitted,
            the module-level name `nb_stacks` is used, which keeps existing
            two-argument calls working.

    Returns:
        A string such as `"[0] PUSH_0 3 [1] POP_0 3"`, with steps separated
        by single spaces (empty string for an empty sequence).

    Raises:
        ValueError: if `nb_stacks` is neither passed nor defined at module
            level.
    """
    if nb_stacks is None:
        # Legacy behaviour: the offset was read from a module-level variable.
        nb_stacks = globals().get("nb_stacks")
        if nb_stacks is None:
            raise ValueError(
                "nb_stacks must be passed to seq_to_str() or defined at module level"
            )

    assert seq.size(0) % 2 == 0

    steps = []
    for t in range(seq.size(0) // 2):
        code = int(seq[2 * t])
        op = f"POP_{code // 2}" if code % 2 == 1 else f"PUSH_{code // 2}"
        val = int(seq[2 * t + 1]) - 2 * nb_stacks
        prefix = ""
        if depth_counts is not None:
            prefix = f"[{depth_counts[2 * t + 1]}] "
        steps.append(f"{prefix}{op} {val}")

    return " ".join(steps)

58 ######################################################################
if __name__ == "__main__":
    # Demo: generate a batch of sequences and print the first few.
    # `nb_stacks` must remain a module-level name here: the seq_to_str()
    # calls below rely on it to decode the value tokens.
    nb, nb_steps, nb_stacks, nb_values = 150000, 10, 1, 5
    seq, depth_counts = generate(
        nb=nb, nb_steps=nb_steps, nb_stacks=nb_stacks, nb_values=nb_values
    )

    for n in range(min(10, seq.size(0))):
        print(seq_to_str(seq[n], depth_counts[n]))