def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
- nb_validated, nb_to_validate = 0, nb_for_train + nb_for_test
+ nb_validated, nb_to_validate = 0, (nb_for_train + nb_for_test) * len(models)
nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate
start_time = time.perf_counter()
mask=(0, 0, 0, 1),
)
- u = model_proba_solutions(model, solved_c_quizzes[:, model.id])
-
- proba_own_solution[:, model.id] = u
+ proba_own_solution[:, model.id] = model_proba_solutions(
+ model, solved_c_quizzes[:, model.id]
+ )
# Now for every model not confident of its response, we pick
# the most consistent from a model which is confident
for s in range(proba_own_solution.size(0)):
- dont_get_it = proba_own_solution[s, :] < args.proba_understands
- if not dont_get_it.all():
+ dont_get_this_quiz = proba_own_solution[s, :] < args.proba_understands
+ if not dont_get_this_quiz.all():
for model in models:
- if dont_get_it[model.id]:
+ if dont_get_this_quiz[model.id]:
+ assert proba_own_solution[s, model.id] < args.proba_understands
proba_other_solutions = model_proba_solutions(
model, solved_c_quizzes[s]
)
- proba_other_solutions[dont_get_it] = -1
+ proba_other_solutions[dont_get_this_quiz] = -1
+ # print(
+ # f"\nDEBUG {proba_own_solution[s,model.id]=} {proba_other_solutions=}\n"
+ # )
i = proba_other_solutions.argmax()
model.recorded_c_quizzes.append(solved_c_quizzes[s, i])
teaching_count[i, model.id] += 1
c_quizzes = new_bag[:128]
- l = [
- quiz_machine.models_logprobas(
- model,
- c_quizzes,
- ("A", "f_A", "B", "f_B"),
- (0, 0, 0, 1),
- (0, 0, 1, 0),
- )
- + quiz_machine.models_logprobas(
- model,
- c_quizzes,
- ("f_A", "A", "f_B", "B"),
- (0, 0, 0, 1),
- (0, 0, 1, 0),
- )
- for model in models
- ]
-
- seq_logprobas = torch.cat([x[:, None] for x in l], dim=1)
-
- probas = seq_logprobas.exp()
-
+ l = [model_proba_solutions(model, c_quizzes) for model in models]
+ probas = torch.cat([x[:, None] for x in l], dim=1)
comments = []
- for l in seq_logprobas:
- comments.append(
- "proba " + " ".join([f"{x.exp().item():.02f}" for x in l])
- )
+ for l in probas:
+ comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
filename = f"culture_c_quiz_{n_epoch:04d}_{model.id:02d}.png"
quiz_machine.problem.save_quizzes_as_image(
import schedulefree
for k in range(args.nb_gpts):
- log_string(f"creating model {k} and its w_quizzes")
+ log_string(f"creating model {k}")
model = mygpt.MyGPT(
vocabulary_size=vocabulary_size,
######################################################################
if args.nb_new_c_quizzes_for_train is None:
- args.nb_new_c_quizzes_for_train = args.nb_train_samples // 100
+ args.nb_new_c_quizzes_for_train = args.nb_train_samples // 40
if args.nb_new_c_quizzes_for_test is None:
- args.nb_new_c_quizzes_for_test = args.nb_test_samples // 100
+ args.nb_new_c_quizzes_for_test = args.nb_test_samples // 40
log_string(
f"nb_new_c_quizzes_for_train {args.nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {args.nb_new_c_quizzes_for_test}"