X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=fc55b9ce7318c45a557488bbb31d81a8d6f07f3d;hb=f2ab5fd489adebe9b34ac825d39e41f13f287cb2;hp=b88cbc4b545c553e36629d43b174cfc056250bd9;hpb=7c79c0b140c88a529962945ec5b482fe90c55581;p=culture.git diff --git a/main.py b/main.py index b88cbc4..fc55b9c 100755 --- a/main.py +++ b/main.py @@ -98,6 +98,19 @@ parser.add_argument("--dirty_debug", action="store_true", default=False) ###################################################################### +grids_tasks = ", ".join( + [x.__name__.removeprefix("task_") for x in grids.Grids().all_tasks] +) + +parser.add_argument( + "--grids_tasks", + type=str, + default=None, + help="A comma-separated subset of: " + grids_tasks + ", or None for all.", +) + +###################################################################### + parser.add_argument("--sky_height", type=int, default=6) parser.add_argument("--sky_width", type=int, default=8) @@ -250,6 +263,7 @@ elif args.problem == "grids": max_nb_cached_chunks=args.nb_gpus * args.nb_train_samples // 100, chunk_size=100, nb_threads=args.nb_threads, + tasks=args.grids_tasks, ) back_accuracy = True else: