From 310911a637e559186d49feb157faafe266d66c6e Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 18 Sep 2024 10:01:41 +0200 Subject: [PATCH] Update. --- main.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index 44035f9..84224e9 100755 --- a/main.py +++ b/main.py @@ -656,9 +656,6 @@ for i in range(args.nb_models): dropout=args.dropout, ).to(main_device) - # if i < args.nb_models//2: - # model = TokenCat(model, 10) - # model = torch.compile(model) model.id = i @@ -740,7 +737,7 @@ def generate_c_quizzes(models, nb, local_device=main_device): log_string(f"generate_c_quizz_speed {int(3600 * nb / duration)}/h") - return torch.cat(record) + return torch.cat(record).to("cpu") ###################################################################### @@ -876,6 +873,7 @@ time_train = 0 def multithread_execution(fun, arguments): + # Single instance, no thread if len(arguments) == 1: return fun(*(arguments[0])) @@ -954,7 +952,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): nb_gpus = len(gpus) nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus - c_quizzes, agreements = multithread_execution( + c_quizzes = multithread_execution( generate_c_quizzes, [(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus], ) -- 2.39.5