Update.
[picoclvr.git] / stack.py
index 675182e..219a1ad 100755 (executable)
--- a/stack.py
+++ b/stack.py
@@ -13,7 +13,9 @@ import torch, torchvision
 # CODE_VAL=val + 2 * nb_stacks
 
 
-def generate_sequences(nb, nb_steps, nb_stacks, nb_digits, device=torch.device("cpu")):
+def generate_sequences(
+    nb, nb_steps, nb_stacks, nb_digits, values=None, device=torch.device("cpu")
+):
     stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64)
     stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64)
     k = torch.arange(nb)
@@ -26,7 +28,10 @@ def generate_sequences(nb, nb_steps, nb_stacks, nb_digits, device=torch.device("
         op = torch.randint(2, (nb,))
         st = torch.randint(nb_stacks, (nb,))
         op = op * (stack_counts[k, st] > 0)
-        val_push = torch.randint(10**nb_digits, (nb,))
+        if values is None:
+            val_push = torch.randint(10**nb_digits, (nb,))
+        else:
+            val_push = values[torch.randint(values.size(0), (nb,))]
         val_pop = stack[
             k,
             st,