Update.
[picoclvr.git] / stack.py
index 312b39f..543f04e 100755 (executable)
--- a/stack.py
+++ b/stack.py
@@ -13,71 +13,95 @@ import torch, torchvision
 # CODE_VAL=val + 2 * nb_stacks
 
 
-def generate_sequences(nb, nb_steps, nb_stacks, nb_values, device=torch.device("cpu")):
+def generate_sequences(
+    nb, nb_steps, nb_stacks, nb_digits, values=None, device=torch.device("cpu")
+):
     stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64)
     stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64)
     k = torch.arange(nb)
-    result = torch.empty(nb, 2 * nb_steps, dtype=torch.int64)
-    recorded_stack_counts = torch.zeros(nb, 2 * nb_steps, dtype=torch.int64)
+    result = torch.empty(nb, (1 + nb_digits) * nb_steps, dtype=torch.int64)
+    recorded_stack_counts = torch.zeros(
+        nb, (1 + nb_digits) * nb_steps, dtype=torch.int64
+    )
 
     for t in range(nb_steps):
         op = torch.randint(2, (nb,))
         st = torch.randint(nb_stacks, (nb,))
         op = op * (stack_counts[k, st] > 0)
-        val_push = torch.randint(nb_values, (nb,))
+        if values is None:
+            val_push = torch.randint(10**nb_digits, (nb,))
+        else:
+            val_push = values[torch.randint(values.size(0), (nb,))]
         val_pop = stack[
             k,
             st,
             (stack_counts[k, st] - 1).clamp(min=0),
         ]
         stack[k, st, stack_counts[k, st]] = val_push
-        recorded_stack_counts[:, 2 * t + 1] = stack_counts[k, st]
+        recorded_stack_counts[:, (1 + nb_digits) * t] = stack_counts[k, st]
         stack_counts[k[op == 0], st[op == 0]] += 1
         stack_counts[k[op == 1], st[op == 1]] -= 1
-        result[:, 2 * t] = st * 2 + op
-        result[:, 2 * t + 1] = (op * val_pop + (1 - op) * val_push) + 2 * nb_stacks
+        result[:, (1 + nb_digits) * t] = st * 2 + op
+        for d in range(nb_digits):
+            result[:, (1 + nb_digits) * t + 1 + d] = (
+                (op * val_pop + (1 - op) * val_push) // (10**d)
+            ) % 10 + 2 * nb_stacks
 
     return result.to(device), recorded_stack_counts.to(device)
 
 
-def remove_poped_values(seq, nb_stacks):
+def remove_popped_values(seq, nb_stacks, nb_digits):
     m = torch.logical_and(seq % 2 == 1, seq < 2 * nb_stacks).long()
-    seq[:, 1:] = -m[:, :-1] + (1 - m[:, :-1]) * seq[:, 1:]
+    for d in range(nb_digits):
+        k = d + 1
+        seq[:, k:] = -m[:, :-k] + (1 - m[:, :-k]) * seq[:, k:]
 
 
-def seq_to_str(seq, recorded_stack_counts=None):
-    assert seq.size(0) % 2 == 0
+def seq_to_str(seq, nb_stacks, nb_digits, recorded_stack_counts=None):
+    assert seq.size(0) % (1 + nb_digits) == 0
     s = ""
-    for t in range(seq.size(0) // 2):
-        op = seq[2 * t]
-        op = f"POP_{op//2}" if op % 2 == 1 else f"PSH_{op//2}"
-        if seq[2 * t + 1] == -1:
-            val = "?"
-        else:
-            val = seq[2 * t + 1] - 2 * nb_stacks
+    for t in range(seq.size(0) // (1 + nb_digits)):
+        n_op = seq[(1 + nb_digits) * t]
         if t > 0:
             s += " "
         if recorded_stack_counts is not None:
-            s += f"[{recorded_stack_counts[2*t+1]}] "
-        s += f"{op} {val}"
+            s += f"[{recorded_stack_counts[(1 + nb_digits)*t]}] "
+        s += f"POP" if n_op % 2 == 1 else f"PSH"
+        if nb_stacks > 1:
+            s += f"_{n_op//2}"
+        for d in range(nb_digits):
+            if seq[(1 + nb_digits) * t + 1 + d] == -1:
+                s += " ?"
+            else:
+                s += f" {seq[(1 + nb_digits) * t + 1 + d] - 2 * nb_stacks:1d}"
     return s
 
 
 ######################################################################
 
 if __name__ == "__main__":
-    nb, nb_steps, nb_stacks, nb_values = 150000, 10, 1, 5
+    nb, nb_steps, nb_stacks, nb_digits = 150000, 20, 2, 1
     seq, recorded_stack_counts = generate_sequences(
-        nb=nb, nb_steps=nb_steps, nb_stacks=nb_stacks, nb_values=nb_values
+        nb=nb,
+        nb_steps=nb_steps,
+        nb_stacks=nb_stacks,
+        nb_digits=nb_digits,
     )
 
     for n in range(min(10, seq.size(0))):
-        # print(seq_to_str(seq[n], recorded_stack_counts[n]))
-        print(seq_to_str(seq[n]))
+        print(
+            seq_to_str(
+                seq[n],
+                nb_stacks=nb_stacks,
+                nb_digits=nb_digits,
+                recorded_stack_counts=recorded_stack_counts[n],
+            )
+        )
+        # print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))
 
-    print("--------------------------------------")
+    print("-- PREPARED FOR TEST -----------------")
 
-    remove_poped_values(seq, nb_stacks)
+    remove_popped_values(seq, nb_stacks, nb_digits)
 
     for n in range(min(10, seq.size(0))):
-        print(seq_to_str(seq[n]))
+        print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))