Update.
[picoclvr.git] / stack.py
index d3be4f8..312b39f 100755 (executable)
--- a/stack.py
+++ b/stack.py
@@ -13,45 +13,52 @@ import torch, torchvision
 # CODE_VAL=val + 2 * nb_stacks
 
 
-def generate(nb, seq_len, nb_stacks, nb_values):
-    stack = torch.empty(nb, nb_stacks, seq_len, dtype=torch.int64)
-    stack_pointers = torch.zeros(nb, nb_stacks, dtype=torch.int64)
+def generate_sequences(nb, nb_steps, nb_stacks, nb_values, device=torch.device("cpu")):
+    stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64)
+    stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64)
     k = torch.arange(nb)
-    result = torch.empty(nb, 2 * seq_len, dtype=torch.int64)
+    result = torch.empty(nb, 2 * nb_steps, dtype=torch.int64)
+    recorded_stack_counts = torch.zeros(nb, 2 * nb_steps, dtype=torch.int64)
 
-    for t in range(seq_len):
+    for t in range(nb_steps):
         op = torch.randint(2, (nb,))
         st = torch.randint(nb_stacks, (nb,))
-        op = op * (stack_pointers[k, st] > 0)
+        op = op * (stack_counts[k, st] > 0)
         val_push = torch.randint(nb_values, (nb,))
-        # top_val[n,s]=stack[n,stack_pointers[n,s]]
-        top_values = stack[
+        val_pop = stack[
             k,
             st,
-            (stack_pointers[k, st] - 1).clamp(min=0),
+            (stack_counts[k, st] - 1).clamp(min=0),
         ]
-        stack[
-            k[:, None].expand_as(stack_pointers),
-            st[:, None].expand_as(stack_pointers),
-            stack_pointers,
-        ] = val_push[:, None].expand_as(stack_pointers)
-        stack_pointers[k[op == 0], st[op == 0]] += 1
-        stack_pointers[k[op == 1], st[op == 1]] -= 1
+        stack[k, st, stack_counts[k, st]] = val_push
+        recorded_stack_counts[:, 2 * t + 1] = stack_counts[k, st]
+        stack_counts[k[op == 0], st[op == 0]] += 1
+        stack_counts[k[op == 1], st[op == 1]] -= 1
         result[:, 2 * t] = st * 2 + op
-        result[:, 2 * t + 1] = (op * top_values + (1 - op) * val_push) + 2 * nb_stacks
+        result[:, 2 * t + 1] = (op * val_pop + (1 - op) * val_push) + 2 * nb_stacks
 
-    return result
+    return result.to(device), recorded_stack_counts.to(device)
 
 
-def seq_to_str(seq):
+def remove_poped_values(seq, nb_stacks):
+    m = torch.logical_and(seq % 2 == 1, seq < 2 * nb_stacks).long()
+    seq[:, 1:] = -m[:, :-1] + (1 - m[:, :-1]) * seq[:, 1:]
+
+
+def seq_to_str(seq, recorded_stack_counts=None):
     assert seq.size(0) % 2 == 0
     s = ""
-    for t in range(0, seq.size(0), 2):
-        op = seq[t]
-        op = f"POP_{op//2}" if op % 2 == 1 else f"PUSH_{op//2}"
-        val = seq[t + 1] - 2 * nb_stacks
+    for t in range(seq.size(0) // 2):
+        op = seq[2 * t]
+        op = f"POP_{op//2}" if op % 2 == 1 else f"PSH_{op//2}"
+        if seq[2 * t + 1] == -1:
+            val = "?"
+        else:
+            val = seq[2 * t + 1] - 2 * nb_stacks
         if t > 0:
             s += " "
+        if recorded_stack_counts is not None:
+            s += f"[{recorded_stack_counts[2*t+1]}] "
         s += f"{op} {val}"
     return s
 
@@ -59,7 +66,18 @@ def seq_to_str(seq):
 ######################################################################
 
 if __name__ == "__main__":
-    nb, seq_len, nb_stacks, nb_values = 3, 10, 1, 5
-    result = generate(nb=nb, seq_len=seq_len, nb_stacks=nb_stacks, nb_values=nb_values)
-    for n in range(result.size(0)):
-        print(seq_to_str(result[n]))
+    nb, nb_steps, nb_stacks, nb_values = 150000, 10, 1, 5
+    seq, recorded_stack_counts = generate_sequences(
+        nb=nb, nb_steps=nb_steps, nb_stacks=nb_stacks, nb_values=nb_values
+    )
+
+    for n in range(min(10, seq.size(0))):
+        # print(seq_to_str(seq[n], recorded_stack_counts[n]))
+        print(seq_to_str(seq[n]))
+
+    print("--------------------------------------")
+
+    remove_poped_values(seq, nb_stacks)
+
+    for n in range(min(10, seq.size(0))):
+        print(seq_to_str(seq[n]))