Update.
[culture.git] / main.py
diff --git a/main.py b/main.py
index 4cf4d59..fc55b9c 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -18,6 +18,8 @@ import sky, grids, quiz_machine
 
 import threading
 
+import torch.multiprocessing as mp
+
 # world quizzes vs. culture quizzes
 
 ######################################################################
@@ -96,6 +98,19 @@ parser.add_argument("--dirty_debug", action="store_true", default=False)
 
 ######################################################################
 
+grids_tasks = ", ".join(
+    [x.__name__.removeprefix("task_") for x in grids.Grids().all_tasks]
+)
+
+parser.add_argument(
+    "--grids_tasks",
+    type=str,
+    default=None,
+    help="A comma-separated subset of: " + grids_tasks + ", or None for all.",
+)
+
+######################################################################
+
 parser.add_argument("--sky_height", type=int, default=6)
 
 parser.add_argument("--sky_width", type=int, default=8)
@@ -248,11 +263,14 @@ elif args.problem == "grids":
         max_nb_cached_chunks=args.nb_gpus * args.nb_train_samples // 100,
         chunk_size=100,
         nb_threads=args.nb_threads,
+        tasks=args.grids_tasks,
     )
     back_accuracy = True
 else:
     raise ValueError
 
+problem.save_some_examples(args.result_dir)
+
 quiz_machine = quiz_machine.QuizMachine(
     problem=problem,
     nb_train_samples=args.nb_train_samples,
@@ -345,8 +363,6 @@ def one_epoch(model, quiz_machine, local_device=None):
 
     run_tests(model, quiz_machine, deterministic_synthesis=False)
 
-    model.TRAINING_LOCK.release()
-
 
 ######################################################################
 
@@ -447,7 +463,6 @@ for k in range(args.nb_gpts):
 
     model.main_test_accuracy = 0.0
     model.id = k
-    model.TRAINING_LOCK = threading.Lock()
 
     model.train_w_quizzes = quiz_machine.generate_token_sequences(args.nb_train_samples)
     quiz_machine.reverse_random_half_in_place(model.train_w_quizzes)
@@ -545,20 +560,21 @@ for n_epoch in range(args.nb_epochs):
 
     weakest_models = ranked_models[: args.nb_gpus]
 
+    threads = []
+
     for gpu_id, model in enumerate(weakest_models):
-        model.TRAINING_LOCK.acquire()
+        log_string(f"training model {model.id}")
 
-        log_string(
-            f"training model {model.id} main_test_accuracy {model.main_test_accuracy}"
+        t = threading.Thread(
+            target=one_epoch, daemon=True, args=(model, quiz_machine, f"cuda:{gpu_id}")
         )
 
-        threading.Thread(
-            target=one_epoch, daemon=True, args=(model, quiz_machine, f"cuda:{gpu_id}")
-        ).start()
+        threads.append(t)
 
-    for model in weakest_models:
-        model.TRAINING_LOCK.acquire()
-        model.TRAINING_LOCK.release()
+        t.start()
+
+    for t in threads:
+        t.join()
 
     ##################################################
     # Replace a fraction of the w_quizzes with fresh ones