X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=4770a125172cf77a41aada959ef5f18b44517be4;hb=757876d57637e0da35f3680ec6ac9573b91f902a;hp=bb1e7b46524f0d170e233e33c3335965e5588636;hpb=4502a109727b0424ff6d4df90f17b361524f9e73;p=picoclvr.git diff --git a/main.py b/main.py index bb1e7b4..4770a12 100755 --- a/main.py +++ b/main.py @@ -37,7 +37,7 @@ parser.add_argument( parser.add_argument("--log_filename", type=str, default="train.log", help=" ") -parser.add_argument("--result_dir", type=str, default="results_default") +parser.add_argument("--result_dir", type=str, default=None) parser.add_argument("--seed", type=int, default=0) @@ -123,6 +123,8 @@ args = parser.parse_args() assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"} +if args.result_dir is None: args.result_dir=f"results_{args.task}" + try: os.mkdir(args.result_dir) except FileExistsError: @@ -968,7 +970,9 @@ class TaskStack(Task): ) #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - input = self.test_input[:10, :50] + l = 50 + l = l - l % (1 + self.nb_digits) + input = self.test_input[:10, :l] result = input.clone() stack.remove_popped_values(result, self.nb_stacks, self.nb_digits) ar_mask = (result != input).long()