# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import torch, torchvision

######################################################################

# CODE_OP=[0 for push, 1 for pop] + 2 * n_stack
# CODE_VAL=val + 2 * nb_stacks
def generate_sequences(
    nb, nb_steps, nb_stacks, nb_digits, values=None, device=torch.device("cpu")
):
    """Generate nb random push/pop sequences over nb_stacks stacks.

    Each of the nb_steps operations is encoded as one op token followed by
    nb_digits decimal digit tokens (least-significant digit first):

        op token    = 2 * stack_index + (0 for push, 1 for pop)
        digit token = digit value + 2 * nb_stacks

    Args:
        nb: number of sequences to generate.
        nb_steps: number of stack operations per sequence.
        nb_stacks: number of independent stacks.
        nb_digits: number of decimal digits per value.
        values: optional 1D int64 tensor restricting the pool of pushable
            values (e.g. to keep train/test values disjoint); None means
            any value in [0, 10**nb_digits).
        device: device the returned tensors are moved to.

    Returns:
        (result, recorded_stack_counts), both int64 of shape
        (nb, (1 + nb_digits) * nb_steps). The second holds, at each
        op-token position, the depth of the addressed stack just before
        the operation (zero elsewhere).
    """
    stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64)
    stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64)
    k = torch.arange(nb)  # batch index, for per-sequence fancy indexing
    result = torch.empty(nb, (1 + nb_digits) * nb_steps, dtype=torch.int64)
    recorded_stack_counts = torch.zeros(
        nb, (1 + nb_digits) * nb_steps, dtype=torch.int64
    )

    for t in range(nb_steps):
        op = torch.randint(2, (nb,))  # what operation (0 push / 1 pop)
        st = torch.randint(nb_stacks, (nb,))  # on what stack
        # a pop is only legal on a non-empty stack; force a push otherwise
        op = op * (stack_counts[k, st] > 0)

        if values is None:  # we can use all the values
            val_push = torch.randint(10**nb_digits, (nb,))
        else:  # values are constrained (e.g. to have train/test values disjoint)
            val_push = values[torch.randint(values.size(0), (nb,))]

        val_pop = stack[  # if we were popping, what value would that be?
            k,
            st,
            (stack_counts[k, st] - 1).clamp(min=0),  # deal with empty stack
        ]

        # we always push the value, but it will be lost if we pop
        # since we will move the count down
        stack[k, st, stack_counts[k, st]] = val_push
        recorded_stack_counts[:, (1 + nb_digits) * t] = stack_counts[k, st]

        # we increase the stack count only when we actually push
        stack_counts[k[op == 0], st[op == 0]] += 1
        stack_counts[k[op == 1], st[op == 1]] -= 1

        # add the operation number to the sequence, that includes the stack number
        result[:, (1 + nb_digits) * t] = st * 2 + op

        # add the digits to the sequence, least-significant first
        for d in range(nb_digits):
            result[:, (1 + nb_digits) * t + 1 + d] = (
                (op * val_pop + (1 - op) * val_push) // (10**d)
            ) % 10 + 2 * nb_stacks

    return result.to(device), recorded_stack_counts.to(device)
def remove_popped_values(seq, nb_stacks, nb_digits):
    """Mask out, in place, the digit tokens that follow every POP op.

    The nb_digits positions after each POP token (odd op codes, i.e. odd
    values < 2 * nb_stacks) are overwritten with -1, so the popped value
    cannot be read off the sequence. Modifies seq in place, returns None.
    """
    # 1 where the token is a POP operation, 0 elsewhere
    m = torch.logical_and(seq % 2 == 1, seq < 2 * nb_stacks).long()
    for d in range(nb_digits):
        k = d + 1  # offset of the d-th digit after its op token
        seq[:, k:] = -m[:, :-k] + (1 - m[:, :-k]) * seq[:, k:]
def seq_to_str(seq, nb_stacks, nb_digits):
    """Render a 1D token sequence as a human-readable string.

    Op tokens become "PSH<stack>"/"POP<stack>", digit tokens become their
    decimal digit, negative (masked) tokens become "?", and anything above
    the digit range (e.g. a separator token) becomes "#".
    nb_digits is kept for signature symmetry with the other helpers.
    """

    def n_to_str(n):
        if n < 0:
            return "?"
        elif n < 2 * nb_stacks:
            s = "POP" if n % 2 == 1 else "PSH"
            return s + f"{n // 2}"
        elif n < 2 * nb_stacks + 10:
            return f"{n - 2 * nb_stacks}"
        else:
            return "#"

    return " ".join([n_to_str(x.item()) for x in seq])
######################################################################
if __name__ == "__main__":
    # Small demo: short sequences printed before and after masking the
    # popped values.
    seq, recorded_stack_counts = generate_sequences(
        nb=10, nb_steps=10, nb_stacks=3, nb_digits=3
    )

    # Append a separator token followed by a copy of the sequence, as a
    # model would see it when asked to reproduce the masked part.
    sep = torch.full((seq.size(0), 1), seq.max() + 1)

    seq = torch.cat([seq, sep, seq], dim=1)

    for n in range(min(10, seq.size(0))):
        print(seq_to_str(seq[n], nb_stacks=3, nb_digits=3))

    remove_popped_values(seq, 3, 3)

    for n in range(min(10, seq.size(0))):
        print(seq_to_str(seq[n], nb_stacks=3, nb_digits=3))

    # Larger generation, to sanity-check sizes on task-scale parameters.
    nb, nb_steps, nb_stacks, nb_digits = 150000, 20, 2, 1
    seq, recorded_stack_counts = generate_sequences(
        nb=nb,
        nb_steps=nb_steps,
        nb_stacks=nb_stacks,
        nb_digits=nb_digits,
    )

    for n in range(min(10, seq.size(0))):
        # recorded_stack_counts[n] holds the per-op stack depths, should a
        # more detailed dump be needed
        print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))

    print("-- PREPARED FOR TEST -----------------")

    print("SANITY", seq.size())

    remove_popped_values(seq, nb_stacks, nb_digits)

    for n in range(min(10, seq.size(0))):
        print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))