Update.
[picoclvr.git] / stack.py
index 675182e..3844161 100755 (executable)
--- a/stack.py
+++ b/stack.py
@@ -13,7 +13,9 @@ import torch, torchvision
 # CODE_VAL=val + 2 * nb_stacks
 
 
-def generate_sequences(nb, nb_steps, nb_stacks, nb_digits, device=torch.device("cpu")):
+def generate_sequences(
+    nb, nb_steps, nb_stacks, nb_digits, values=None, device=torch.device("cpu")
+):
     stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64)
     stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64)
     k = torch.arange(nb)
@@ -26,7 +28,10 @@ def generate_sequences(nb, nb_steps, nb_stacks, nb_digits, device=torch.device("
         op = torch.randint(2, (nb,))
         st = torch.randint(nb_stacks, (nb,))
         op = op * (stack_counts[k, st] > 0)
-        val_push = torch.randint(10**nb_digits, (nb,))
+        if values is None:
+            val_push = torch.randint(10**nb_digits, (nb,))
+        else:
+            val_push = values[torch.randint(values.size(0), (nb,))]
         val_pop = stack[
             k,
             st,
@@ -75,7 +80,7 @@ def seq_to_str(seq, nb_stacks, nb_digits, recorded_stack_counts=None):
 ######################################################################
 
 if __name__ == "__main__":
-    nb, nb_steps, nb_stacks, nb_digits = 150000, 10, 1, 1
+    nb, nb_steps, nb_stacks, nb_digits = 150000, 20, 2, 1
     seq, recorded_stack_counts = generate_sequences(
         nb=nb,
         nb_steps=nb_steps,
@@ -83,13 +88,11 @@ if __name__ == "__main__":
         nb_digits=nb_digits,
     )
 
-    print("-- TRAIN -----------------------------")
-
     for n in range(min(10, seq.size(0))):
         # print(seq_to_str(seq[n], recorded_stack_counts[n]))
         print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))
 
-    print("-- TEST ------------------------------")
+    print("-- PREPARED FOR TEST -----------------")
 
     remove_popped_values(seq, nb_stacks, nb_digits)