X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=9a3d34633bd88cff4ac1e05ee77989e4d725e7b1;hb=960c93d7c0aea41d180814c46d3a05686a426764;hp=af94979e937d852912dc79c01ada589436de461f;hpb=d2844d7a2d09ef38dc6f62d5e131059cccc872c5;p=picoclvr.git diff --git a/main.py b/main.py index af94979..9a3d346 100755 --- a/main.py +++ b/main.py @@ -33,7 +33,7 @@ parser.add_argument( "--task", type=str, default="sandbox", - help="sandbox, picoclvr, mnist, maze, snake, stack, expr, rpl, world", + help="byheart, learnop, guessop, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl", ) parser.add_argument("--log_filename", type=str, default="train.log", help=" ") @@ -42,6 +42,10 @@ parser.add_argument("--result_dir", type=str, default=None) parser.add_argument("--seed", type=int, default=0) +parser.add_argument("--max_percents_of_test_in_train", type=int, default=1) + +######################################## + parser.add_argument("--nb_epochs", type=int, default=None) parser.add_argument("--batch_size", type=int, default=None) @@ -56,6 +60,8 @@ parser.add_argument("--learning_rate", type=float, default=1e-4) parser.add_argument("--learning_rate_schedule", type=str, default="10: 2e-5,30: 4e-6") +######################################## + parser.add_argument("--model", type=str, default="37M") parser.add_argument("--dim_model", type=int, default=None) @@ -70,6 +76,8 @@ parser.add_argument("--nb_blocks", type=int, default=None) parser.add_argument("--dropout", type=float, default=0.1) +######################################## + parser.add_argument("--deterministic_synthesis", action="store_true", default=False) parser.add_argument("--no_checkpoint", action="store_true", default=False) @@ -91,17 +99,6 @@ parser.add_argument("--rpl_nb_runs", type=int, default=8) parser.add_argument("--rpl_no_prog", action="store_true", default=False) -############################## -# sandbox options - -parser.add_argument("--sandbox_level", type=int, default=0) - -parser.add_argument("--sandbox_levels_nb_items", type=int, default=25) - -parser.add_argument("--sandbox_levels_len_source", type=int, default=6) - -parser.add_argument("--sandbox_levels_len_result", type=int, default=8) - ############################## # picoclvr options @@ -333,31 +330,52 @@ picoclvr_pruner_eval = ( ###################################################################### -if args.task == "sandbox": - if args.sandbox_level == 0: - problem = problems.ProblemLevel0( - nb_sentences=args.sandbox_levels_nb_items, - len_prompt=args.sandbox_levels_len_source, - len_result=args.sandbox_levels_len_result, - ) - elif args.sandbox_level == 1: - problem = problems.ProblemLevel1( - nb_operators=args.sandbox_levels_nb_items, - len_source=args.sandbox_levels_len_source, - len_result=args.sandbox_levels_len_result, - ) - elif args.sandbox_level == 2: - problem = problems.ProblemLevel2( - len_source=args.sandbox_levels_len_source, - len_result=args.sandbox_levels_len_result, - ) - else: - raise ValueError(f"Unknown sandbox level {args.sandbox_level}") +if args.task == "byheart": + task = tasks.SandBox( + problem=problems.ProblemByHeart(), + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + logger=log_string, + device=device, + ) + + +elif args.task == "learnop": + task = tasks.SandBox( + problem=problems.ProblemLearnOperator(), + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + logger=log_string, + device=device, + ) + + +elif args.task == "guessop": + task = tasks.SandBox( + problem=problems.ProblemGuessOperator(), + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + logger=log_string, + device=device, + ) + + +elif args.task == "twotargets": + task = tasks.SandBox( + problem=problems.ProblemTwoTargets(), + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + logger=log_string, + device=device, + ) +elif args.task == "addition": task = tasks.SandBox( - # problem, - # problems.ProblemAddition(zero_padded=False, inverted_result=False), - problems.ProblemLenId(len_max=args.sandbox_levels_len_source), + problem=problems.ProblemAddition(), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.batch_size, @@ -570,8 +588,8 @@ log_string( ) assert ( - nb_in_train <= nb_test // 100 -), "More than 1% of test samples are in the train set" + nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100 +), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set" ##############################