From c5daf2eeedb26a25789de370171d592c621a2fac Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 1 Jul 2023 20:50:15 +0200 Subject: [PATCH] Update. --- main.py | 13 ++++++++++--- stack.py | 15 +++++++++------ 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/main.py b/main.py index 0323d02..14b1bc3 100755 --- a/main.py +++ b/main.py @@ -109,7 +109,7 @@ parser.add_argument("--snake_length", type=int, default=200) ############################## # Snake options -parser.add_argument("--stack_nb_steps", type=int, default=25) +parser.add_argument("--stack_nb_steps", type=int, default=100) parser.add_argument("--stack_nb_stacks", type=int, default=1) @@ -166,9 +166,9 @@ default_args = { "nb_test_samples": 10000, }, "stack": { - "nb_epochs": 25, + "nb_epochs": 5, "batch_size": 25, - "nb_train_samples": 10000, + "nb_train_samples": 100000, "nb_test_samples": 1000, }, } @@ -892,6 +892,13 @@ class TaskStack(Task): nb_test_samples, nb_steps, nb_stacks, nb_values, self.device ) + mask = self.test_input.clone() + stack.remove_poped_values(mask,self.nb_stacks) + mask=(mask!=self.test_input) + counts = self.test_stack_counts.flatten()[mask.flatten()] + counts=F.one_hot(counts).sum(0) + log_string(f"stack_count {counts}") + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 def batches(self, split="train", nb_to_use=-1, desc=None): diff --git a/stack.py b/stack.py index 312b39f..ba452aa 100755 --- a/stack.py +++ b/stack.py @@ -45,12 +45,13 @@ def remove_poped_values(seq, nb_stacks): seq[:, 1:] = -m[:, :-1] + (1 - m[:, :-1]) * seq[:, 1:] -def seq_to_str(seq, recorded_stack_counts=None): +def seq_to_str(seq, show_stack_nb=True,recorded_stack_counts=None): assert seq.size(0) % 2 == 0 s = "" for t in range(seq.size(0) // 2): - op = seq[2 * t] - op = f"POP_{op//2}" if op % 2 == 1 else f"PSH_{op//2}" + n_op = seq[2 * t] + op = f"POP" if n_op % 2 == 1 else f"PSH" + if show_stack_nb: op+=f"_{n_op//2}" if seq[2 * t + 1] == -1: val = "?" else: @@ -71,13 +72,15 @@ if __name__ == "__main__": nb=nb, nb_steps=nb_steps, nb_stacks=nb_stacks, nb_values=nb_values ) + print("-- TRAIN -----------------------------") + for n in range(min(10, seq.size(0))): # print(seq_to_str(seq[n], recorded_stack_counts[n])) - print(seq_to_str(seq[n])) + print(seq_to_str(seq[n],show_stack_nb=nb_stacks>1)) - print("--------------------------------------") + print("-- TEST ------------------------------") remove_poped_values(seq, nb_stacks) for n in range(min(10, seq.size(0))): - print(seq_to_str(seq[n])) + print(seq_to_str(seq[n],show_stack_nb=nb_stacks>1)) -- 2.39.5