Update.
author: François Fleuret <francois@fleuret.org>
Thu, 12 Sep 2024 17:14:47 +0000 (19:14 +0200)
committer: François Fleuret <francois@fleuret.org>
Thu, 12 Sep 2024 17:14:47 +0000 (19:14 +0200)
main.py

diff --git a/main.py b/main.py
index f6f37d6..80e99fd 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -1228,6 +1228,8 @@ def thread_generate_ae_c_quizzes(models, nb, record, local_device=main_device):
 def save_c_quizzes_with_scores(models, c_quizzes, filename, solvable_only=False):
     l = []
 
+    c_quizzes = c_quizzes.to(main_device)
+
     with torch.autograd.no_grad():
         if solvable_only:
             to_keep, _ = quiz_validation(
@@ -1374,43 +1376,23 @@ def multithread_execution(fun, arguments):
     for t in threads:
         t.join()
 
-    if records == []:
+    if records[0] is None:
         return
+
     else:
         return [
-            torch.cat([x[k] for x in records], dim=0) for k in range(len(records[0]))
+            torch.cat([x[k].to("cpu") for x in records], dim=0)
+            for k in range(len(records[0]))
         ]
 
 
 # ----- test
 
-# nb_gpus = len(gpus)
-# nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus
-
-# c_quizzes, agreements = multithread_execution(
-# generate_ae_c_quizzes,
-# [(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus],
-# )
-
 # ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
 # weakest_models = ranked_models[: len(gpus)]
 
 # n_epoch = 14
 
-# multithread_execution(
-# one_ae_epoch,
-# [
-# (
-# model,
-# quiz_machine,
-# n_epoch,
-# None if c_quizzes is None else c_quizzes[agreements[:, model.id]],
-# gpu,
-# )
-# for model, gpu in zip(weakest_models, gpus)
-# ],
-# )
-
 ######################################################################
 
 for n_epoch in range(current_epoch, args.nb_epochs):
@@ -1459,35 +1441,40 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
         # --------------------------------------------------------------------
 
-        records, threads = [], []
-
-        start_time = time.perf_counter()
-
-        if len(gpus) > 1:
-            for gpu in gpus:
-                t = threading.Thread(
-                    target=thread_generate_ae_c_quizzes,
-                    daemon=True,
-                    args=(models, nb_c_quizzes_to_generate, records, gpu),
-                )
-
-                # To get a different sequence between threads
-                log_string(f"dummy {torch.rand(1)}")
-                threads.append(t)
-                t.start()
-
-            for t in threads:
-                t.join()
-
-        else:
-            records.append(
-                generate_ae_c_quizzes(models, nb_c_quizzes_to_generate, gpus[0])
-            )
-
-        time_c_quizzes = int(time.perf_counter() - start_time)
+        c_quizzes, agreements = multithread_execution(
+            generate_ae_c_quizzes,
+            [(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus],
+        )
 
-        c_quizzes = torch.cat([q.to(main_device) for q, _ in records], dim=0)
-        agreements = torch.cat([a.to(main_device) for _, a in records], dim=0)
+        ## records, threads = [], []
+        ##
+        ## start_time = time.perf_counter()
+        ##
+        ## if len(gpus) > 1:
+        ## for gpu in gpus:
+        ## t = threading.Thread(
+        ## target=thread_generate_ae_c_quizzes,
+        ## daemon=True,
+        ## args=(models, nb_c_quizzes_to_generate, records, gpu),
+        ## )
+        ##
+        ## # To get a different sequence between threads
+        ## log_string(f"dummy {torch.rand(1)}")
+        ## threads.append(t)
+        ## t.start()
+        ##
+        ## for t in threads:
+        ## t.join()
+        ##
+        ## else:
+        ## records.append(
+        ## generate_ae_c_quizzes(models, nb_c_quizzes_to_generate, gpus[0])
+        ## )
+        ##
+        ## time_c_quizzes = int(time.perf_counter() - start_time)
+        ##
+        ## c_quizzes = torch.cat([q.to(main_device) for q, _ in records], dim=0)
+        ## agreements = torch.cat([a.to(main_device) for _, a in records], dim=0)
 
         # --------------------------------------------------------------------
 
@@ -1520,40 +1507,54 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
     weakest_models = ranked_models[: len(gpus)]
 
-    threads = []
-
-    start_time = time.perf_counter()
-
-    if len(gpus) > 1:
-        for gpu, model in zip(gpus, weakest_models):
-            log_string(f"training model {model.id} (accuracy {model.test_accuracy})")
-            if c_quizzes is None:
-                c_quizzes_for_this_model = None
-            else:
-                c_quizzes_for_this_model = c_quizzes[agreements[:, model.id]]
-
-            t = threading.Thread(
-                target=one_ae_epoch,
-                daemon=True,
-                args=(model, quiz_machine, n_epoch, c_quizzes_for_this_model, gpu),
+    multithread_execution(
+        one_ae_epoch,
+        [
+            (
+                model,
+                quiz_machine,
+                n_epoch,
+                None if c_quizzes is None else c_quizzes[agreements[:, model.id]],
+                gpu,
             )
+            for model, gpu in zip(weakest_models, gpus)
+        ],
+    )
 
-            threads.append(t)
-
-            t.start()
-
-        for t in threads:
-            t.join()
-
-    else:
-        model = weakest_models[0]
-        log_string(f"training model {model.id} (accuracy {model.test_accuracy})")
-        if c_quizzes is None:
-            c_quizzes_for_this_model = None
-        else:
-            c_quizzes_for_this_model = c_quizzes[agreements[:, model.id]]
-
-        one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes_for_this_model, gpus[0])
+    ## threads = []
+    ##
+    ## start_time = time.perf_counter()
+    ##
+    ## if len(gpus) > 1:
+    ## for gpu, model in zip(gpus, weakest_models):
+    ## log_string(f"training model {model.id} (accuracy {model.test_accuracy})")
+    ## if c_quizzes is None:
+    ## c_quizzes_for_this_model = None
+    ## else:
+    ## c_quizzes_for_this_model = c_quizzes[agreements[:, model.id]]
+    ##
+    ## t = threading.Thread(
+    ## target=one_ae_epoch,
+    ## daemon=True,
+    ## args=(model, quiz_machine, n_epoch, c_quizzes_for_this_model, gpu),
+    ## )
+    ##
+    ## threads.append(t)
+    ##
+    ## t.start()
+    ##
+    ## for t in threads:
+    ## t.join()
+    ##
+    ## else:
+    ## model = weakest_models[0]
+    ## log_string(f"training model {model.id} (accuracy {model.test_accuracy})")
+    ## if c_quizzes is None:
+    ## c_quizzes_for_this_model = None
+    ## else:
+    ## c_quizzes_for_this_model = c_quizzes[agreements[:, model.id]]
+    ##
+    ## one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes_for_this_model, gpus[0])
 
     time_train += int(time.perf_counter() - start_time)