Update.
[culture.git] / main.py
diff --git a/main.py b/main.py
index b88cbc4..fc55b9c 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -98,6 +98,19 @@ parser.add_argument("--dirty_debug", action="store_true", default=False)
 
 ######################################################################
 
+grids_tasks = ", ".join(
+    [x.__name__.removeprefix("task_") for x in grids.Grids().all_tasks]
+)
+
+parser.add_argument(
+    "--grids_tasks",
+    type=str,
+    default=None,
+    help="A comma-separated subset of: " + grids_tasks + ", or None for all.",
+)
+
+######################################################################
+
 parser.add_argument("--sky_height", type=int, default=6)
 
 parser.add_argument("--sky_width", type=int, default=8)
@@ -250,6 +263,7 @@ elif args.problem == "grids":
         max_nb_cached_chunks=args.nb_gpus * args.nb_train_samples // 100,
         chunk_size=100,
         nb_threads=args.nb_threads,
+        tasks=args.grids_tasks,
     )
     back_accuracy = True
 else: