temperature_cold=args.temperature_cold,
)
- recorded_too_simple.append(
- keep_good_quizzes(models, c_quizzes, required_nb_failures=0)
+ c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
+
+ nc = quiz_machine.solution_nb_correct(models, c_quizzes)
+
+ count_nc = tuple(
+ n.item() for n in F.one_hot(nc, num_classes=len(models) + 1).sum(dim=0)
)
- c_quizzes = keep_good_quizzes(models, c_quizzes)
+ log_string(f"nb_correct {count_nc}")
+
+ recorded_too_simple.append(c_quizzes[nc == len(models)])
+
+ c_quizzes = c_quizzes[nc == len(models) - 1]
nb_validated[model_for_generation.id] += c_quizzes.size(0)
total_nb_validated = nb_validated.sum().item()
######################################################################
# save images
- vq = validated_quizzes[:128]
+ vq = validated_quizzes[torch.randperm(validated_quizzes.size(0))[:128]]
if vq.size(0) > 0:
prefix = f"culture_c_quiz_{n_epoch:04d}"
args.result_dir, prefix, vq, show_part_to_predict=False
)
- vq = too_simple_quizzes[:128]
+ vq = too_simple_quizzes
if vq.size(0) > 0:
prefix = f"culture_c_quiz_{n_epoch:04d}_too_simple"
######################################################################
for n_epoch in range(current_epoch, args.nb_epochs):
+ state = {"current_epoch": n_epoch}
+ filename = "state.pth"
+ torch.save(state, os.path.join(args.result_dir, filename))
+ log_string(f"wrote {filename}")
+
log_string(f"--- epoch {n_epoch} ----------------------------------------")
cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models])
)
log_string(f"wrote {filename}")
- state = {"current_epoch": n_epoch}
- filename = "state.pth"
- torch.save(state, os.path.join(args.result_dir, filename))
- log_string(f"wrote {filename}")
-
# Renew the training samples
for model in weakest_models: