X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=68b946a1ed19580d232fb3593a0f2b3541fded57;hb=b718ef527d4bfb014a9ad564bb5199c7d0780aa9;hp=af94979e937d852912dc79c01ada589436de461f;hpb=d2844d7a2d09ef38dc6f62d5e131059cccc872c5;p=culture.git diff --git a/main.py b/main.py index af94979..68b946a 100755 --- a/main.py +++ b/main.py @@ -42,6 +42,10 @@ parser.add_argument("--result_dir", type=str, default=None) parser.add_argument("--seed", type=int, default=0) +parser.add_argument("--max_percents_of_test_in_train", type=int, default=1) + +######################################## + parser.add_argument("--nb_epochs", type=int, default=None) parser.add_argument("--batch_size", type=int, default=None) @@ -56,6 +60,8 @@ parser.add_argument("--learning_rate", type=float, default=1e-4) parser.add_argument("--learning_rate_schedule", type=str, default="10: 2e-5,30: 4e-6") +######################################## + parser.add_argument("--model", type=str, default="37M") parser.add_argument("--dim_model", type=int, default=None) @@ -70,6 +76,8 @@ parser.add_argument("--nb_blocks", type=int, default=None) parser.add_argument("--dropout", type=float, default=0.1) +######################################## + parser.add_argument("--deterministic_synthesis", action="store_true", default=False) parser.add_argument("--no_checkpoint", action="store_true", default=False) @@ -570,8 +578,8 @@ log_string( ) assert ( - nb_in_train <= nb_test // 100 -), "More than 1% of test samples are in the train set" + nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100 +), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set" ##############################