From c4c6c79528c7e64d1c8449ebb17f4b277856ec90 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 18 Sep 2024 09:27:44 +0200 Subject: [PATCH] Update. --- main.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/main.py b/main.py index 22854a9..380be1e 100755 --- a/main.py +++ b/main.py @@ -67,7 +67,7 @@ parser.add_argument("--nb_train_alien_samples", type=int, default=0) parser.add_argument("--nb_test_alien_samples", type=int, default=0) -parser.add_argument("--nb_c_quizzes", type=int, default=10000) +parser.add_argument("--nb_c_quizzes", type=int, default=2500) parser.add_argument("--c_quiz_multiplier", type=int, default=1) @@ -710,10 +710,10 @@ def generate_c_quizzes(models, nb, local_device=main_device): nb_wrong >= args.nb_have_to_be_wrong ) - nb_validated += to_keep.long().sum() + nb_validated += to_keep.long().sum().item() record.append(c_quizzes[to_keep]) - log_string(f"generate_c_quizzes {nb_validated}") + # log_string(f"generate_c_quizzes {nb_validated}") ##################### @@ -722,8 +722,8 @@ def generate_c_quizzes(models, nb, local_device=main_device): if last_log < 0 or duration > last_log + 10: last_log = duration if nb_validated > 0: - if nb_validated < wanted_nb: - d = (wanted_nb - nb_validated) * duration / nb_validated + if nb_validated < nb: + d = (nb - nb_validated) * duration / nb_validated e = ( datetime.datetime.now() + datetime.timedelta(seconds=d) ).strftime("%a %H:%M") @@ -740,7 +740,7 @@ def generate_c_quizzes(models, nb, local_device=main_device): duration = time.perf_counter() - start_time - log_string(f"generate_c_quizz_speed {int(3600 * wanted_nb / duration)}/h") + log_string(f"generate_c_quizz_speed {int(3600 * nb / duration)}/h") return torch.cat(record) -- 2.39.5