From 834abd9307ac63258501ea17bd4eb9227d4ecd52 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 15 Jul 2024 08:21:16 +0200 Subject: [PATCH] Update. --- main.py | 25 +++++++++++++++++++++---- quiz_machine.py | 2 -- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/main.py b/main.py index 6b00bbf..cdaacdf 100755 --- a/main.py +++ b/main.py @@ -284,8 +284,6 @@ problem.save_some_examples(args.result_dir) quiz_machine = quiz_machine.QuizMachine( problem=problem, - nb_train_samples=args.nb_train_samples, - nb_test_samples=args.nb_test_samples, back_accuracy=back_accuracy, batch_size=args.physical_batch_size, result_dir=args.result_dir, @@ -414,11 +412,15 @@ def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100): nb_validated = 0 + start_time = time.perf_counter() + + nb_to_generate_per_iteration = nb_to_create + while nb_validated < nb_to_create: model_for_generation = models[torch.randint(len(models), (1,))] c_quizzes = quiz_machine.generate_quizzes( - nb_to_create, + nb_to_generate_per_iteration, model_for_generation=model_for_generation, temperature=args.generation_temperature, ) @@ -437,8 +439,19 @@ def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100): if validated_quizzes is not None: nb_validated = validated_quizzes.size(0) + duration = time.perf_counter() - start_time + + if nb_validated > 0: + e = (nb_to_create - nb_validated) * duration / nb_validated + if e > 0: + e = "~" + str(datetime.timedelta(seconds=int(e))) + else: + e = "0s" + else: + e = "???" + log_string( - f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create}" + f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create} (remaining time {e})" ) # store the new c_quizzes which have been validated @@ -595,6 +608,10 @@ if args.dirty_debug: args.nb_new_c_quizzes_for_train = 100 args.nb_new_c_quizzes_for_test = 10 + def compute_valid_quizzes(token_logprobas): + l = token_logprobas.sum(dim=-1).sort(dim=-1).values + return torch.rand(l[:, 0].size(), device=l.device) < 0.5 + ###################################################################### diff --git a/quiz_machine.py b/quiz_machine.py index bc468d3..927a349 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -244,8 +244,6 @@ class QuizMachine: def __init__( self, problem, - nb_train_samples, - nb_test_samples, back_accuracy, batch_size, result_dir, -- 2.39.5