Oups
[picoclvr.git] / stack.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import torch, torchvision
9
10 ######################################################################
11
12 # CODE_OP=[0 for push, 1 for pop] + 2 * n_stack
13 # CODE_VAL=val + 2 * nb_stacks
14
15
16 def generate_sequences(
17     nb, nb_steps, nb_stacks, nb_digits, values=None, device=torch.device("cpu")
18 ):
19     stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64)
20     stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64)
21     k = torch.arange(nb)
22     result = torch.empty(nb, (1 + nb_digits) * nb_steps, dtype=torch.int64)
23     recorded_stack_counts = torch.zeros(
24         nb, (1 + nb_digits) * nb_steps, dtype=torch.int64
25     )
26
27     for t in range(nb_steps):
28         op = torch.randint(2, (nb,))
29         st = torch.randint(nb_stacks, (nb,))
30         op = op * (stack_counts[k, st] > 0)
31         if values is None:
32             val_push = torch.randint(10**nb_digits, (nb,))
33         else:
34             val_push = values[torch.randint(values.size(0), (nb,))]
35         val_pop = stack[
36             k,
37             st,
38             (stack_counts[k, st] - 1).clamp(min=0),
39         ]
40         stack[k, st, stack_counts[k, st]] = val_push
41         recorded_stack_counts[:, (1 + nb_digits) * t] = stack_counts[k, st]
42         stack_counts[k[op == 0], st[op == 0]] += 1
43         stack_counts[k[op == 1], st[op == 1]] -= 1
44         result[:, (1 + nb_digits) * t] = st * 2 + op
45         for d in range(nb_digits):
46             result[:, (1 + nb_digits) * t + 1 + d] = (
47                 (op * val_pop + (1 - op) * val_push) // (10**d)
48             ) % 10 + 2 * nb_stacks
49
50     return result.to(device), recorded_stack_counts.to(device)
51
52
53 def remove_popped_values(seq, nb_stacks, nb_digits):
54     m = torch.logical_and(seq % 2 == 1, seq < 2 * nb_stacks).long()
55     for d in range(nb_digits):
56         k = d + 1
57         seq[:, k:] = -m[:, :-k] + (1 - m[:, :-k]) * seq[:, k:]
58
59
60 def seq_to_str(seq, nb_stacks, nb_digits, recorded_stack_counts=None):
61     assert seq.size(0) % (1 + nb_digits) == 0
62     s = ""
63     for t in range(seq.size(0) // (1 + nb_digits)):
64         n_op = seq[(1 + nb_digits) * t]
65         if t > 0:
66             s += " "
67         if recorded_stack_counts is not None:
68             s += f"[{recorded_stack_counts[(1 + nb_digits)*t]}] "
69         s += f"POP" if n_op % 2 == 1 else f"PSH"
70         if nb_stacks > 1:
71             s += f"_{n_op//2}"
72         for d in range(nb_digits):
73             if seq[(1 + nb_digits) * t + 1 + d] == -1:
74                 s += " ?"
75             else:
76                 s += f" {seq[(1 + nb_digits) * t + 1 + d] - 2 * nb_stacks:1d}"
77     return s
78
79
80 ######################################################################
81
82 if __name__ == "__main__":
83     nb, nb_steps, nb_stacks, nb_digits = 150000, 20, 2, 1
84     seq, recorded_stack_counts = generate_sequences(
85         nb=nb,
86         nb_steps=nb_steps,
87         nb_stacks=nb_stacks,
88         nb_digits=nb_digits,
89     )
90
91     for n in range(min(10, seq.size(0))):
92         print(
93             seq_to_str(
94                 seq[n],
95                 nb_stacks=nb_stacks,
96                 nb_digits=nb_digits,
97                 recorded_stack_counts=recorded_stack_counts[n],
98             )
99         )
100         # print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))
101
102     print("-- PREPARED FOR TEST -----------------")
103
104     remove_popped_values(seq, nb_stacks, nb_digits)
105
106     for n in range(min(10, seq.size(0))):
107         print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))