parser.add_argument("--nb_hints", type=int, default=5)
+parser.add_argument("--nb_runs", type=int, default=5)
+
parser.add_argument("--dirty_debug", action="store_true", default=False)
parser.add_argument("--test", type=str, default=None)
c_quizzes,
local_device,
nb_have_to_be_correct=3,
- nb_have_to_be_not_correct=0,
nb_have_to_be_wrong=1,
nb_mistakes_to_be_wrong=5,
nb_hints=0,
quad_order=("A", "f_A", "B", "f_B"),
quad_mask=quad,
)
+
+    # sub_correct: at least one of the nb_runs runs is mistake-free, so it
+    # must start False and be OR-ed; starting it True would make the OR a
+    # no-op and `correct & sub_correct` below could never reject anything.
+    # sub_wrong: every run exceeds the mistake threshold, so it starts True
+    # and is AND-ed.
+    sub_correct, sub_wrong = False, True
    for _ in range(nb_runs):
        if nb_hints == 0:
            mask_hints = None
            )
        nb_mistakes = (result != c_quizzes).long().sum(dim=1)
-        correct = correct & (nb_mistakes == 0)
-        wrong = wrong | (nb_mistakes >= nb_mistakes_to_be_wrong)
+        sub_correct = sub_correct | (nb_mistakes == 0)
+        sub_wrong = sub_wrong & (nb_mistakes >= nb_mistakes_to_be_wrong)
+
+    correct = correct & sub_correct
+    wrong = wrong | sub_wrong
record_wrong.append(wrong[:, None])
nb_correct += correct.long()
if c_quizzes.size(0) > 0:
to_keep, record_wrong = quiz_validation(
- models, c_quizzes, local_device, nb_hints=args.nb_hints
+ models,
+ c_quizzes,
+ local_device,
+ nb_hints=args.nb_hints,
+ nb_runs=args.nb_runs,
)
q = c_quizzes[to_keep]
def save_c_quizzes_with_scores(models, c_quizzes, nb, filename, solvable_only=False):
l = []
- if solvable_only:
- to_keep, _ = quiz_validation(
- models,
- c_quizzes,
- main_device,
- nb_have_to_be_correct=1,
- nb_have_to_be_wrong=0,
- nb_hints=0,
- )
- c_quizzes = c_quizzes[to_keep]
+ with torch.autograd.no_grad():
+ if solvable_only:
+ to_keep, _ = quiz_validation(
+ models,
+ c_quizzes,
+ main_device,
+ nb_have_to_be_correct=1,
+ nb_have_to_be_wrong=0,
+ nb_hints=0,
+ )
+ c_quizzes = c_quizzes[to_keep]
- c_quizzes = c_quizzes[
- torch.randperm(c_quizzes.size(0), device=c_quizzes.device)[:nb]
- ]
+ c_quizzes = c_quizzes[
+ torch.randperm(c_quizzes.size(0), device=c_quizzes.device)[:nb]
+ ]
- with torch.autograd.no_grad():
for model in models:
model = copy.deepcopy(model).to(main_device).eval()
l.append(model_ae_proba_solutions(model, c_quizzes))