X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=19f918c5de7357a35e646e0217ccf068dc8f8890;hb=e781d77071fa26f393f50451f91c70f4a0850ca5;hp=e3fd9f0c37696a878b50f578d4d44c9141e1b0b2;hpb=a3211f96c7426a613b82a2de87d4dd70640e8f46;p=picoclvr.git diff --git a/main.py b/main.py index e3fd9f0..19f918c 100755 --- a/main.py +++ b/main.py @@ -82,6 +82,17 @@ parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth") ############################## # picoclvr options +parser.add_argument("--sandbox_level", type=int, default=0) + +parser.add_argument("--sandbox_levels_nb_items", type=int, default=25) + +parser.add_argument("--sandbox_levels_len_source", type=int, default=6) + +parser.add_argument("--sandbox_levels_len_result", type=int, default=8) + +############################## +# picoclvr options + parser.add_argument("--picoclvr_nb_colors", type=int, default=5) parser.add_argument("--picoclvr_height", type=int, default=12) @@ -152,9 +163,9 @@ if args.result_dir is None: default_args = { "sandbox": { - "nb_epochs": 10, + "nb_epochs": 50, "batch_size": 25, - "nb_train_samples": 25000, + "nb_train_samples": 100000, "nb_test_samples": 10000, }, "picoclvr": { @@ -265,8 +276,28 @@ picoclvr_pruner_eval = ( ###################################################################### if args.task == "sandbox": + if args.sandbox_level == 0: + problem = tasks.ProblemLevel0( + nb_sentences=args.sandbox_levels_nb_items, + len_prompt=args.sandbox_levels_len_source, + len_result=args.sandbox_levels_len_result, + ) + elif args.sandbox_level == 1: + problem = tasks.ProblemLevel1( + nb_operators=args.sandbox_levels_nb_items, + len_source=args.sandbox_levels_len_source, + len_result=args.sandbox_levels_len_result, + ) + elif args.sandbox_level == 2: + problem = tasks.ProblemLevel2( + len_source=args.sandbox_levels_len_source, + len_result=args.sandbox_levels_len_result, + ) + else: + raise ValueError(f"Unknown sandbox level {args.sandbox_level}") + task = tasks.SandBox( - tasks.ProblemLevel2(), + problem, # tasks.ProblemAddition(zero_padded=False, inverted_result=False), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples,