model, other_models, quiz_machine, n_epoch, c_quizzes, local_device=main_device
):
model.train().to(local_device)
+ optimizer_to(model.optimizer, local_device)
nb_train_samples, acc_train_loss = 0, 0.0
n_epoch, models, c_quizzes, suffix=None, local_device=main_device
):
for model in models:
- models.eval().to(local_device)
+ model.eval().to(local_device)
c_quizzes = c_quizzes.to(local_device)
with torch.autograd.no_grad():
log_probas = sum(
[model_ae_proba_solutions(model, c_quizzes) for model in models]
)
- i = log_probas.sort().values
+ i = log_probas.sort().indices
suffix = "" if suffix is None else "_" + suffix
quiz_machine.problem.save_quizzes_as_image(
args.result_dir,
filename,
- quizzes=quizzes[i[:128]],
+ quizzes=c_quizzes[i[:128]],
# predicted_parts=predicted_parts,
# correct_parts=correct_parts,
- comments=comments,
+ # comments=comments,
delta=True,
nrow=8,
)
+ log_string(f"wrote {filename}")
+
def generate_ae_c_quizzes(models, local_device=main_device):
criteria = [
log_string(f"{time_train=} {time_c_quizzes=}")
if (
- min([m.test_accuracy for m in models]) > args.accuracy_to_make_c_quizzes
+ min([float(m.test_accuracy) for m in models]) > args.accuracy_to_make_c_quizzes
and time_train >= time_c_quizzes
):
if c_quizzes is not None:
- save_badness_statistics(models, c_quizzes)
+ save_badness_statistics(last_n_epoch_c_quizzes, models, c_quizzes, "after")
last_n_epoch_c_quizzes = n_epoch
start_time = time.perf_counter()
for model in models:
model.test_accuracy = 0
+ save_badness_statistics(n_epoch, models, c_quizzes, "before")
+
if c_quizzes is None:
log_string("no_c_quiz")
else: