parser.add_argument("--inference_batch_size", type=int, default=25)
-parser.add_argument("--nb_train_samples", type=int, default=100000)
+parser.add_argument("--nb_train_samples", type=int, default=50000)
parser.add_argument("--nb_test_samples", type=int, default=1000)
wrong = torch.cat(record_wrong, dim=1)
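+ # nb_correct / nb_wrong: presumably per-quiz counts of models that solved / failed each candidate; wrong concatenates the per-model error records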
- return to_keep, wrong
+ return to_keep, nb_correct, nb_wrong, wrong
######################################################################
c_quizzes = c_quizzes[to_keep]
if c_quizzes.size(0) > 0:
- to_keep, record_wrong = quiz_validation(
+ to_keep, nb_correct, nb_wrong, record_wrong = quiz_validation(
models,
c_quizzes,
local_device,
with torch.autograd.no_grad():
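+ # validation is inference only, so gradient tracking is disabled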
if solvable_only:
- to_keep, _ = quiz_validation(
+ to_keep, nb_correct, nb_wrong, record_wrong = quiz_validation(
models,
c_quizzes,
main_device,
)
c_quizzes = c_quizzes[to_keep]
- for model in models:
- model = copy.deepcopy(model).to(main_device).eval()
- l.append(model_ae_proba_solutions(model, c_quizzes))
-
- probas = torch.cat([x[:, None] for x in l], dim=1)
-
comments = []
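+ # per-quiz summaries of how many models got each candidate right / wrong, presumably used as captions by save_quizzes_as_image below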
- for l in probas:
- comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
+ for c, w in zip(nb_correct, nb_wrong):
+ comments.append(f"nb_correct {c} nb_wrong {w}")
quiz_machine.problem.save_quizzes_as_image(
args.result_dir,