From a0cb8172c89f80b762f863c66c4163742fa90cd5 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Tue, 13 Aug 2024 17:47:44 +0200
Subject: [PATCH] Update.

---
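Notes on the diff, kept here below the "---" so that `git am` leaves the
commit message untouched.

The hunk at line 593 implements the rule stated by its comment: every
model that is not confident in its own solution to a quiz is handed the
candidate solution, authored by a confident model, to which it assigns
the highest probability (teaching_count records who taught whom). A
minimal standalone sketch of that selection rule, where pick_teachers,
proba_cross, and the 0.9 default threshold are illustrative assumptions
rather than interfaces from main.py:

import torch

def pick_teachers(proba_own_solution, proba_cross, proba_understands=0.9):
    # proba_own_solution: (nb_quizzes, nb_models), probability each model
    # assigns to its own solution of each quiz.
    # proba_cross: (nb_quizzes, nb_models, nb_models), where
    # proba_cross[s, m, i] is the probability model m assigns to the
    # solution that model i produced for quiz s.
    nb_quizzes, nb_models = proba_own_solution.shape
    teacher = -torch.ones(nb_quizzes, nb_models, dtype=torch.int64)
    for s in range(nb_quizzes):
        dont_get_this_quiz = proba_own_solution[s] < proba_understands
        if not dont_get_this_quiz.all():  # at least one confident model
            for m in range(nb_models):
                if dont_get_this_quiz[m]:
                    p = proba_cross[s, m].clone()
                    p[dont_get_this_quiz] = -1  # drop unconfident authors
                    teacher[s, m] = p.argmax()
    return teacher  # teacher[s, m] == -1 where model m needed no help

For instance, with proba_own_solution = torch.tensor([[0.95, 0.10]]),
model 1 falls below the threshold and is assigned the solution authored
by model 0, the only confident author left after masking.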
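The hunk at line 647 replaces the inline pair of
quiz_machine.models_logprobas() calls with model_proba_solutions()
(defined near line 545, per the first hunk header). Since the comment
strings now format x.item() without .exp(), that helper presumably
returns probabilities directly. A guess at its body, reconstructed from
the deleted lines and assuming quiz_machine is visible as a global in
main.py; this is a sketch, not the actual definition:

def model_proba_solutions(m, quizzes):
    # log-probas of the masked grids under both quiz orderings,
    # summed, then exponentiated into probabilities
    l = quiz_machine.models_logprobas(
        m, quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
    ) + quiz_machine.models_logprobas(
        m, quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
    )
    return l.exp()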
 main.py | 56 +++++++++++++++++++-------------------------------------
 1 file changed, 19 insertions(+), 37 deletions(-)

diff --git a/main.py b/main.py
index 34c0987..b2a9591 100755
--- a/main.py
+++ b/main.py
@@ -545,7 +545,7 @@ def model_proba_solutions(m, quizzes):
 
 
 def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
-    nb_validated, nb_to_validate = 0, nb_for_train + nb_for_test
+    nb_validated, nb_to_validate = 0, (nb_for_train + nb_for_test) * len(models)
     nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate
     start_time = time.perf_counter()
 
@@ -593,22 +593,26 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
                 mask=(0, 0, 0, 1),
             )
 
-            u = model_proba_solutions(model, solved_c_quizzes[:, model.id])
-
-            proba_own_solution[:, model.id] = u
+            proba_own_solution[:, model.id] = model_proba_solutions(
+                model, solved_c_quizzes[:, model.id]
+            )
 
         # Now for every model not confident of its response, we pick
         # the most consistent from a model which is confident
 
         for s in range(proba_own_solution.size(0)):
-            dont_get_it = proba_own_solution[s, :] < args.proba_understands
-            if not dont_get_it.all():
+            dont_get_this_quiz = proba_own_solution[s, :] < args.proba_understands
+            if not dont_get_this_quiz.all():
                 for model in models:
-                    if dont_get_it[model.id]:
+                    if dont_get_this_quiz[model.id]:
+                        assert proba_own_solution[s, model.id] < args.proba_understands
                         proba_other_solutions = model_proba_solutions(
                             model, solved_c_quizzes[s]
                         )
-                        proba_other_solutions[dont_get_it] = -1
+                        proba_other_solutions[dont_get_this_quiz] = -1
+                        # print(
+                        #     f"\nDEBUG {proba_own_solution[s,model.id]=} {proba_other_solutions=}\n"
+                        # )
                         i = proba_other_solutions.argmax()
                         model.recorded_c_quizzes.append(solved_c_quizzes[s, i])
                         teaching_count[i, model.id] += 1
@@ -647,34 +651,12 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
 
         c_quizzes = new_bag[:128]
 
-        l = [
-            quiz_machine.models_logprobas(
-                model,
-                c_quizzes,
-                ("A", "f_A", "B", "f_B"),
-                (0, 0, 0, 1),
-                (0, 0, 1, 0),
-            )
-            + quiz_machine.models_logprobas(
-                model,
-                c_quizzes,
-                ("f_A", "A", "f_B", "B"),
-                (0, 0, 0, 1),
-                (0, 0, 1, 0),
-            )
-            for model in models
-        ]
-
-        seq_logprobas = torch.cat([x[:, None] for x in l], dim=1)
-
-        probas = seq_logprobas.exp()
-
+        l = [model_proba_solutions(model, c_quizzes) for model in models]
+        probas = torch.cat([x[:, None] for x in l], dim=1)
         comments = []
-        for l in seq_logprobas:
-            comments.append(
-                "proba " + " ".join([f"{x.exp().item():.02f}" for x in l])
-            )
+        for l in probas:
+            comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
 
         filename = f"culture_c_quiz_{n_epoch:04d}_{model.id:02d}.png"
 
         quiz_machine.problem.save_quizzes_as_image(
@@ -699,7 +681,7 @@ if args.schedule_free:
     import schedulefree
 
 for k in range(args.nb_gpts):
-    log_string(f"creating model {k} and its w_quizzes")
+    log_string(f"creating model {k}")
 
     model = mygpt.MyGPT(
         vocabulary_size=vocabulary_size,
@@ -764,10 +746,10 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 ######################################################################
 
 if args.nb_new_c_quizzes_for_train is None:
-    args.nb_new_c_quizzes_for_train = args.nb_train_samples // 100
+    args.nb_new_c_quizzes_for_train = args.nb_train_samples // 40
 
 if args.nb_new_c_quizzes_for_test is None:
-    args.nb_new_c_quizzes_for_test = args.nb_test_samples // 100
+    args.nb_new_c_quizzes_for_test = args.nb_test_samples // 40
 
 log_string(
     f"nb_new_c_quizzes_for_train {args.nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {args.nb_new_c_quizzes_for_test}"
-- 
2.39.5