X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=9a3d34633bd88cff4ac1e05ee77989e4d725e7b1;hb=960c93d7c0aea41d180814c46d3a05686a426764;hp=ed4adf52b62731b06995522e58dc7a49cb58352f;hpb=59600257e0eda86816a43676c5ffbe598d78bdb5;p=picoclvr.git diff --git a/main.py b/main.py index ed4adf5..9a3d346 100755 --- a/main.py +++ b/main.py @@ -33,7 +33,7 @@ parser.add_argument( "--task", type=str, default="sandbox", - help="sandbox, picoclvr, mnist, maze, snake, stack, expr, rpl, world", + help="byheart, learnop, guessop, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl", ) parser.add_argument("--log_filename", type=str, default="train.log", help=" ") @@ -99,17 +99,6 @@ parser.add_argument("--rpl_nb_runs", type=int, default=8) parser.add_argument("--rpl_no_prog", action="store_true", default=False) -############################## -# sandbox options - -parser.add_argument("--sandbox_level", type=int, default=0) - -parser.add_argument("--sandbox_levels_nb_items", type=int, default=25) - -parser.add_argument("--sandbox_levels_len_source", type=int, default=6) - -parser.add_argument("--sandbox_levels_len_result", type=int, default=8) - ############################## # picoclvr options @@ -341,32 +330,52 @@ picoclvr_pruner_eval = ( ###################################################################### -if args.task == "sandbox": - if args.sandbox_level == 0: - problem = problems.ProblemLevel0( - nb_sentences=args.sandbox_levels_nb_items, - len_prompt=args.sandbox_levels_len_source, - len_result=args.sandbox_levels_len_result, - ) - elif args.sandbox_level == 1: - problem = problems.ProblemLevel1( - nb_operators=args.sandbox_levels_nb_items, - len_source=args.sandbox_levels_len_source, - len_result=args.sandbox_levels_len_result, - ) - elif args.sandbox_level == 2: - problem = problems.ProblemLevel2( - len_source=args.sandbox_levels_len_source, - len_result=args.sandbox_levels_len_result, - ) - else: - raise ValueError(f"Unknown sandbox level {args.sandbox_level}") +if args.task == "byheart": + task = tasks.SandBox( + problem=problems.ProblemByHeart(), + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + logger=log_string, + device=device, + ) + + +elif args.task == "learnop": + task = tasks.SandBox( + problem=problems.ProblemLearnOperator(), + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + logger=log_string, + device=device, + ) + + +elif args.task == "guessop": + task = tasks.SandBox( + problem=problems.ProblemGuessOperator(), + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + logger=log_string, + device=device, + ) + + +elif args.task == "twotargets": + task = tasks.SandBox( + problem=problems.ProblemTwoTargets(), + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + logger=log_string, + device=device, + ) +elif args.task == "addition": task = tasks.SandBox( - # problem, - # problems.ProblemAddition(zero_padded=False, inverted_result=False), - # problems.ProblemLenId(len_max=args.sandbox_levels_len_source), - problems.ProblemTwoTargets(len_total=16, len_targets=4), + problem=problems.ProblemAddition(), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.batch_size,