From 51f2897a7eb14ae72ee7eee788d876915ead4370 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 16 Jul 2024 20:47:42 +0200 Subject: [PATCH] Update. --- main.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index 2b71950..41efc86 100755 --- a/main.py +++ b/main.py @@ -410,12 +410,13 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 start_time = time.perf_counter() - nb_validated = torch.zeros(len(models)) + nb_validated = torch.zeros(len(models), dtype=torch.int64) while nb_validated.sum() < nb_to_create: # We balance the number of quizzes per model - model_for_generation = models[nb_validated.argmin()] + model_for_generation = sorted(models, key=lambda m: nb_validated[m.id])[0] + print(nb_validated, "using", model_for_generation.id) c_quizzes = quiz_machine.generate_c_quizzes( nb_to_generate_per_iteration, @@ -426,7 +427,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 c_quizzes = keep_good_quizzes(models, c_quizzes) - nb_validated[model.id] += c_quizzes.size(0) + nb_validated[model_for_generation.id] += c_quizzes.size(0) total_nb_validated = nb_validated.sum().item() recorded.append(c_quizzes) @@ -442,7 +443,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 e = "???" log_string( - f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_create} (finishes {e} -- {(total_nb_validated * 3600)/duration}/h)" + f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_create} (finishes {e} -- {(total_nb_validated * 3600)/duration:0.1f}/h)" ) validated_quizzes = torch.cat(recorded, dim=0) -- 2.39.5