X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=ed4adf52b62731b06995522e58dc7a49cb58352f;hb=59600257e0eda86816a43676c5ffbe598d78bdb5;hp=af94979e937d852912dc79c01ada589436de461f;hpb=d2844d7a2d09ef38dc6f62d5e131059cccc872c5;p=picoclvr.git diff --git a/main.py b/main.py index af94979..ed4adf5 100755 --- a/main.py +++ b/main.py @@ -42,6 +42,10 @@ parser.add_argument("--result_dir", type=str, default=None) parser.add_argument("--seed", type=int, default=0) +parser.add_argument("--max_percents_of_test_in_train", type=int, default=1) + +######################################## + parser.add_argument("--nb_epochs", type=int, default=None) parser.add_argument("--batch_size", type=int, default=None) @@ -56,6 +60,8 @@ parser.add_argument("--learning_rate", type=float, default=1e-4) parser.add_argument("--learning_rate_schedule", type=str, default="10: 2e-5,30: 4e-6") +######################################## + parser.add_argument("--model", type=str, default="37M") parser.add_argument("--dim_model", type=int, default=None) @@ -70,6 +76,8 @@ parser.add_argument("--nb_blocks", type=int, default=None) parser.add_argument("--dropout", type=float, default=0.1) +######################################## + parser.add_argument("--deterministic_synthesis", action="store_true", default=False) parser.add_argument("--no_checkpoint", action="store_true", default=False) @@ -357,7 +365,8 @@ if args.task == "sandbox": task = tasks.SandBox( # problem, # problems.ProblemAddition(zero_padded=False, inverted_result=False), - problems.ProblemLenId(len_max=args.sandbox_levels_len_source), + # problems.ProblemLenId(len_max=args.sandbox_levels_len_source), + problems.ProblemTwoTargets(len_total=16, len_targets=4), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.batch_size, @@ -570,8 +579,8 @@ log_string( ) assert ( - nb_in_train <= nb_test // 100 -), "More than 1% of test samples are in the train set" + nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100 +), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set" ##############################