##############################
# Snake options
-parser.add_argument("--stack_nb_steps", type=int, default=25)
+parser.add_argument("--stack_nb_steps", type=int, default=100)
parser.add_argument("--stack_nb_stacks", type=int, default=1)
"nb_test_samples": 10000,
},
"stack": {
- "nb_epochs": 25,
+ "nb_epochs": 5,
"batch_size": 25,
- "nb_train_samples": 10000,
+ "nb_train_samples": 100000,
"nb_test_samples": 1000,
},
}
nb_test_samples, nb_steps, nb_stacks, nb_values, self.device
)
+ mask = self.test_input.clone()
+ stack.remove_poped_values(mask,self.nb_stacks)
+ mask=(mask!=self.test_input)
+ counts = self.test_stack_counts.flatten()[mask.flatten()]
+ counts=F.one_hot(counts).sum(0)
+ log_string(f"stack_count {counts}")
+
self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
def batches(self, split="train", nb_to_use=-1, desc=None):
seq[:, 1:] = -m[:, :-1] + (1 - m[:, :-1]) * seq[:, 1:]
-def seq_to_str(seq, recorded_stack_counts=None):
+def seq_to_str(seq, show_stack_nb=True,recorded_stack_counts=None):
assert seq.size(0) % 2 == 0
s = ""
for t in range(seq.size(0) // 2):
- op = seq[2 * t]
- op = f"POP_{op//2}" if op % 2 == 1 else f"PSH_{op//2}"
+ n_op = seq[2 * t]
+ op = f"POP" if n_op % 2 == 1 else f"PSH"
+ if show_stack_nb: op+=f"_{n_op//2}"
if seq[2 * t + 1] == -1:
val = "?"
else:
nb=nb, nb_steps=nb_steps, nb_stacks=nb_stacks, nb_values=nb_values
)
+ print("-- TRAIN -----------------------------")
+
for n in range(min(10, seq.size(0))):
# print(seq_to_str(seq[n], recorded_stack_counts[n]))
- print(seq_to_str(seq[n]))
+ print(seq_to_str(seq[n],show_stack_nb=nb_stacks>1))
- print("--------------------------------------")
+ print("-- TEST ------------------------------")
remove_poped_values(seq, nb_stacks)
for n in range(min(10, seq.size(0))):
- print(seq_to_str(seq[n]))
+ print(seq_to_str(seq[n],show_stack_nb=nb_stacks>1))