Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 12 Sep 2024 12:36:08 +0000 (14:36 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 12 Sep 2024 12:36:08 +0000 (14:36 +0200)
main.py

diff --git a/main.py b/main.py
index 3753e9b..f6f37d6 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -634,6 +634,7 @@ def sample_x_t_minus_1_given_x_0_x_t(x_0, x_t):
     return x_t_minus_1
 
 
+######################################################################
 # Non-uniform transitions, to be fixed?
 
 
@@ -1350,6 +1351,68 @@ c_quizzes = None
 time_c_quizzes = 0
 time_train = 0
 
+######################################################################
+
+
+def multithread_execution(fun, arguments):
+    if len(arguments) == 1:
+        return fun(*(arguments[0]))
+
+    records, threads = [], []
+
+    def threadable_fun(*args):
+        records.append(fun(*args))
+
+    for args in arguments:
+        t = threading.Thread(target=threadable_fun, daemon=True, args=args)
+
+        # To get a different sequence between threads
+        log_string(f"dummy_rand {torch.rand(1)}")
+        threads.append(t)
+        t.start()
+
+    for t in threads:
+        t.join()
+
+    if records == []:
+        return
+    else:
+        return [
+            torch.cat([x[k] for x in records], dim=0) for k in range(len(records[0]))
+        ]
+
+
+# ----- test
+
+# nb_gpus = len(gpus)
+# nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus
+
+# c_quizzes, agreements = multithread_execution(
+# generate_ae_c_quizzes,
+# [(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus],
+# )
+
+# ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
+# weakest_models = ranked_models[: len(gpus)]
+
+# n_epoch = 14
+
+# multithread_execution(
+# one_ae_epoch,
+# [
+# (
+# model,
+# quiz_machine,
+# n_epoch,
+# None if c_quizzes is None else c_quizzes[agreements[:, model.id]],
+# gpu,
+# )
+# for model, gpu in zip(weakest_models, gpus)
+# ],
+# )
+
+######################################################################
+
 for n_epoch in range(current_epoch, args.nb_epochs):
     start_time = time.perf_counter()