return x_t_minus_1
+######################################################################
# Non-uniform transitions, to be fixed?
time_c_quizzes = 0
time_train = 0
+######################################################################
+
+
+def multithread_execution(fun, arguments):
+    """Run fun once per argument tuple, each call in its own thread.
+
+    fun is expected to return a tuple/list of tensors (or None). The
+    per-call results are concatenated component-wise along dim 0, in
+    the order of `arguments`, and returned as a list of tensors.
+    Returns None when no call produced a result.
+    """
+    # Single call: no threading overhead, return fun's result as-is.
+    if len(arguments) == 1:
+        return fun(*(arguments[0]))
+
+    # One pre-allocated slot per argument tuple, so results keep the
+    # order of `arguments` regardless of thread completion order.
+    records = [None] * len(arguments)
+    threads = []
+
+    def threadable_fun(slot, *args):
+        records[slot] = fun(*args)
+
+    for slot, args in enumerate(arguments):
+        t = threading.Thread(
+            target=threadable_fun, daemon=True, args=(slot,) + tuple(args)
+        )
+
+        # To get a different sequence between threads
+        log_string(f"dummy_rand {torch.rand(1)}")
+        threads.append(t)
+        t.start()
+
+    for t in threads:
+        t.join()
+
+    # A call that returned None produced no record; drop its slot so
+    # the component-wise torch.cat below does not trip over None.
+    records = [r for r in records if r is not None]
+
+    if records == []:
+        return
+    else:
+        return [
+            torch.cat([x[k] for x in records], dim=0) for k in range(len(records[0]))
+        ]
+
+
+# ----- example usage of multithread_execution (commented out)
+
+# nb_gpus = len(gpus)
+# nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus
+
+# c_quizzes, agreements = multithread_execution(
+# generate_ae_c_quizzes,
+# [(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus],
+# )
+
+# ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
+# weakest_models = ranked_models[: len(gpus)]
+
+# n_epoch = 14
+
+# multithread_execution(
+# one_ae_epoch,
+# [
+# (
+# model,
+# quiz_machine,
+# n_epoch,
+# None if c_quizzes is None else c_quizzes[agreements[:, model.id]],
+# gpu,
+# )
+# for model, gpu in zip(weakest_models, gpus)
+# ],
+# )
+
+######################################################################
+
for n_epoch in range(current_epoch, args.nb_epochs):
start_time = time.perf_counter()