Update.
[picoclvr.git] / main.py
diff --git a/main.py b/main.py
index bb1e7b4..4770a12 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -37,7 +37,7 @@ parser.add_argument(
 
 parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
 
-parser.add_argument("--result_dir", type=str, default="results_default")
+parser.add_argument("--result_dir", type=str, default=None)
 
 parser.add_argument("--seed", type=int, default=0)
 
@@ -123,6 +123,8 @@ args = parser.parse_args()
 
 assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
 
+if args.result_dir is None: args.result_dir=f"results_{args.task}"
+
 try:
     os.mkdir(args.result_dir)
 except FileExistsError:
@@ -968,7 +970,9 @@ class TaskStack(Task):
             )
 
             #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-            input = self.test_input[:10, :50]
+            l = 50
+            l = l - l % (1 + self.nb_digits)
+            input = self.test_input[:10, :l]
             result = input.clone()
             stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
             ar_mask = (result != input).long()