projects
/
culture.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[culture.git]
/
stack.py
diff --git
a/stack.py
b/stack.py
index
675182e
..
543f04e
100755
(executable)
--- a/
stack.py
+++ b/
stack.py
@@
-13,7
+13,9
@@
import torch, torchvision
# CODE_VAL=val + 2 * nb_stacks
# CODE_VAL=val + 2 * nb_stacks
-def generate_sequences(nb, nb_steps, nb_stacks, nb_digits, device=torch.device("cpu")):
+def generate_sequences(
+ nb, nb_steps, nb_stacks, nb_digits, values=None, device=torch.device("cpu")
+):
stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64)
stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64)
k = torch.arange(nb)
stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64)
stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64)
k = torch.arange(nb)
@@
-26,14
+28,17
@@
def generate_sequences(nb, nb_steps, nb_stacks, nb_digits, device=torch.device("
op = torch.randint(2, (nb,))
st = torch.randint(nb_stacks, (nb,))
op = op * (stack_counts[k, st] > 0)
op = torch.randint(2, (nb,))
st = torch.randint(nb_stacks, (nb,))
op = op * (stack_counts[k, st] > 0)
- val_push = torch.randint(10**nb_digits, (nb,))
+ if values is None:
+ val_push = torch.randint(10**nb_digits, (nb,))
+ else:
+ val_push = values[torch.randint(values.size(0), (nb,))]
val_pop = stack[
k,
st,
(stack_counts[k, st] - 1).clamp(min=0),
]
stack[k, st, stack_counts[k, st]] = val_push
val_pop = stack[
k,
st,
(stack_counts[k, st] - 1).clamp(min=0),
]
stack[k, st, stack_counts[k, st]] = val_push
- recorded_stack_counts[:, (1 + nb_digits) * t
+ 1
] = stack_counts[k, st]
+ recorded_stack_counts[:, (1 + nb_digits) * t] = stack_counts[k, st]
stack_counts[k[op == 0], st[op == 0]] += 1
stack_counts[k[op == 1], st[op == 1]] -= 1
result[:, (1 + nb_digits) * t] = st * 2 + op
stack_counts[k[op == 0], st[op == 0]] += 1
stack_counts[k[op == 1], st[op == 1]] -= 1
result[:, (1 + nb_digits) * t] = st * 2 + op
@@
-59,6
+64,8
@@
def seq_to_str(seq, nb_stacks, nb_digits, recorded_stack_counts=None):
n_op = seq[(1 + nb_digits) * t]
if t > 0:
s += " "
n_op = seq[(1 + nb_digits) * t]
if t > 0:
s += " "
+ if recorded_stack_counts is not None:
+ s += f"[{recorded_stack_counts[(1 + nb_digits)*t]}] "
s += f"POP" if n_op % 2 == 1 else f"PSH"
if nb_stacks > 1:
s += f"_{n_op//2}"
s += f"POP" if n_op % 2 == 1 else f"PSH"
if nb_stacks > 1:
s += f"_{n_op//2}"
@@
-67,15
+74,13
@@
def seq_to_str(seq, nb_stacks, nb_digits, recorded_stack_counts=None):
s += " ?"
else:
s += f" {seq[(1 + nb_digits) * t + 1 + d] - 2 * nb_stacks:1d}"
s += " ?"
else:
s += f" {seq[(1 + nb_digits) * t + 1 + d] - 2 * nb_stacks:1d}"
- if recorded_stack_counts is not None:
- s += f"[{recorded_stack_counts[(1 + nb_digits)*t+1]}] "
return s
######################################################################
if __name__ == "__main__":
return s
######################################################################
if __name__ == "__main__":
- nb, nb_steps, nb_stacks, nb_digits = 150000,
10, 1
, 1
+ nb, nb_steps, nb_stacks, nb_digits = 150000,
20, 2
, 1
seq, recorded_stack_counts = generate_sequences(
nb=nb,
nb_steps=nb_steps,
seq, recorded_stack_counts = generate_sequences(
nb=nb,
nb_steps=nb_steps,
@@
-83,13
+88,18
@@
if __name__ == "__main__":
nb_digits=nb_digits,
)
nb_digits=nb_digits,
)
- print("-- TRAIN -----------------------------")
-
for n in range(min(10, seq.size(0))):
for n in range(min(10, seq.size(0))):
- # print(seq_to_str(seq[n], recorded_stack_counts[n]))
- print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))
-
- print("-- TEST ------------------------------")
+ print(
+ seq_to_str(
+ seq[n],
+ nb_stacks=nb_stacks,
+ nb_digits=nb_digits,
+ recorded_stack_counts=recorded_stack_counts[n],
+ )
+ )
+ # print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))
+
+ print("-- PREPARED FOR TEST -----------------")
remove_popped_values(seq, nb_stacks, nb_digits)
remove_popped_values(seq, nb_stacks, nb_digits)