######################################################################
-def model_ae_proba_solutions(model, input):
+def model_ae_proba_solutions(model, input, log_proba=False):
record = []
for q in input.split(args.batch_size):
loss = torch.cat(record, dim=0)
- return (-loss).exp()
+ if log_proba:
+ return -loss
+ else:
+ return (-loss).exp()
nb_diffusion_iterations = 25
)
+def save_badness_statistics(
+ n_epoch, models, c_quizzes, suffix=None, local_device=main_device
+):
+ for model in models:
+ models.eval().to(local_device)
+ c_quizzes = c_quizzes.to(local_device)
+ with torch.autograd.no_grad():
+ log_probas = sum(
+ [model_ae_proba_solutions(model, c_quizzes) for model in models]
+ )
+ i = log_probas.sort().values
+
+ suffix = "" if suffix is None else "_" + suffix
+
+ filename = f"culture_badness_{n_epoch:04d}{suffix}.png"
+
+ quiz_machine.problem.save_quizzes_as_image(
+ args.result_dir,
+ filename,
+ quizzes=quizzes[i[:128]],
+ # predicted_parts=predicted_parts,
+ # correct_parts=correct_parts,
+ comments=comments,
+ delta=True,
+ nrow=8,
+ )
+
+
def generate_ae_c_quizzes(models, local_device=main_device):
criteria = [
# c_quiz_criterion_only_one,
state = torch.load(os.path.join(args.result_dir, filename))
log_string(f"successfully loaded {filename}")
current_epoch = state["current_epoch"]
+ c_quizzes = state["c_quizzes"]
# total_time_generating_c_quizzes = state["total_time_generating_c_quizzes"]
# total_time_training_models = state["total_time_training_models"]
# common_c_quiz_bags = state["common_c_quiz_bags"]
state = {
"current_epoch": n_epoch,
+ "c_quizzes": c_quizzes,
# "total_time_generating_c_quizzes": total_time_generating_c_quizzes,
# "total_time_training_models": total_time_training_models,
# "common_c_quiz_bags": common_c_quiz_bags,
}
+
filename = "state.pth"
torch.save(state, os.path.join(args.result_dir, filename))
log_string(f"wrote {filename}")
log_string(f"{time_train=} {time_c_quizzes=}")
if (
- n_epoch >= 200
- and min([m.test_accuracy for m in models]) > args.accuracy_to_make_c_quizzes
+ min([m.test_accuracy for m in models]) > args.accuracy_to_make_c_quizzes
and time_train >= time_c_quizzes
):
+ if c_quizzes is not None:
+ save_badness_statistics(models, c_quizzes)
+
last_n_epoch_c_quizzes = n_epoch
start_time = time.perf_counter()
c_quizzes = generate_ae_c_quizzes(models, local_device=main_device)