projects
/
picoclvr.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[picoclvr.git]
/
stack.py
diff --git
a/stack.py
b/stack.py
index
675182e
..
458ec79
100755
(executable)
--- a/
stack.py
+++ b/
stack.py
@@
-13,7
+13,9
@@
import torch, torchvision
# CODE_VAL=val + 2 * nb_stacks
# CODE_VAL=val + 2 * nb_stacks
-def generate_sequences(nb, nb_steps, nb_stacks, nb_digits, device=torch.device("cpu")):
+def generate_sequences(
+ nb, nb_steps, nb_stacks, nb_digits, values=None, device=torch.device("cpu")
+):
stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64)
stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64)
k = torch.arange(nb)
stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64)
stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64)
k = torch.arange(nb)
@@
-26,7
+28,10
@@
def generate_sequences(nb, nb_steps, nb_stacks, nb_digits, device=torch.device("
op = torch.randint(2, (nb,))
st = torch.randint(nb_stacks, (nb,))
op = op * (stack_counts[k, st] > 0)
op = torch.randint(2, (nb,))
st = torch.randint(nb_stacks, (nb,))
op = op * (stack_counts[k, st] > 0)
- val_push = torch.randint(10**nb_digits, (nb,))
+ if values is None:
+ val_push = torch.randint(10**nb_digits, (nb,))
+ else:
+ val_push = values[torch.randint(values.size(0), (nb,))]
val_pop = stack[
k,
st,
val_pop = stack[
k,
st,
@@
-89,7
+94,7
@@
if __name__ == "__main__":
# print(seq_to_str(seq[n], recorded_stack_counts[n]))
print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))
# print(seq_to_str(seq[n], recorded_stack_counts[n]))
print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))
- print("--
TEST -------------
-----------------")
+ print("--
PREPARED FOR TEST
-----------------")
remove_popped_values(seq, nb_stacks, nb_digits)
remove_popped_values(seq, nb_stacks, nb_digits)