def ae_generate(model, x_0, mask_generate, nb_iterations_max=50, mask_hints=None):
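    # Sketch of the intent, inferred from the code below: cells selected by
    # mask_generate are replaced by uniform color noise and then refined by the
    # model for up to nb_iterations_max passes, while the remaining cells keep
    # their values from x_0. quiz_machine and deterministic() are module-level
    # helpers defined elsewhere in the file.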
    noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
-     if mask_hints is None:
-         x_t = (1 - mask_generate) * x_0 + mask_generate * noise
-     else:
-         mask = mask_generate * (1 - mask_hints)
-         x_t = (1 - mask) * x_0 + mask * noise
-
    one_iteration_prediction = deterministic(mask_generate)[:, None]
+     # Hinted cells are treated as given: they are removed from the generation
+     # mask, so a single expression builds x_t with or without hints.
+     if mask_hints is not None:
+         mask_generate = mask_generate * (1 - mask_hints)
+
+     x_t = (1 - mask_generate) * x_0 + mask_generate * noise
+
    changed = True
    for it in range(nb_iterations_max):
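
# In the validation pass below, quiz_validation is now called with explicit
# keyword arguments, and the hard-coded nb_hints=0 and nb_runs=1 are replaced
# by the values received from the caller.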
for q in c_quizzes.split(args.inference_batch_size):
    record.append(
        quiz_validation(
-             models,
-             q,
-             local_device,
-             nb_have_to_be_correct,
-             nb_have_to_be_wrong,
-             nb_mistakes_to_be_wrong,
-             nb_hints=0,
-             nb_runs=1,
+             models=models,
+             c_quizzes=q,
+             local_device=local_device,
+             nb_have_to_be_correct=nb_have_to_be_correct,
+             nb_have_to_be_wrong=nb_have_to_be_wrong,
+             nb_mistakes_to_be_wrong=nb_mistakes_to_be_wrong,
+             nb_hints=nb_hints,
+             nb_runs=nb_runs,
        )
    )
quad_mask=quad,
)
- sub_correct, sub_wrong = True, True
+ sub_correct, sub_wrong = False, True
for _ in range(nb_runs):
    # without hints, the full generation mask is used
    if nb_hints == 0:
        mask_hints = None
nb_correct += correct.long()
nb_wrong += wrong.long()
+ # log_string(f"{nb_hints=} {nb_correct=}")
+ # log_string(f"{nb_hints=} {nb_wrong=}")
+
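# A candidate quiz is kept only if at least nb_have_to_be_correct models solve
# it and at least nb_have_to_be_wrong models fail on it.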
to_keep = (nb_correct >= nb_have_to_be_correct) & (nb_wrong >= nb_have_to_be_wrong)
wrong = torch.cat(record_wrong, dim=1)
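
# Final validation of the kept c_quizzes on the main device: the threshold is
# raised from one to two models that must solve each quiz, and no model is
# required to fail.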
    models,
    c_quizzes,
    main_device,
-     nb_have_to_be_correct=1,
+     nb_have_to_be_correct=2,
    nb_have_to_be_wrong=0,
    nb_hints=0,
)
c_quizzes = c_quizzes[to_keep]
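# The random subsampling to nb quizzes is removed: every quiz that passes the
# validation filter is kept.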
- c_quizzes = c_quizzes[
-     torch.randperm(c_quizzes.size(0), device=c_quizzes.device)[:nb]
- ]
-
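# Score the kept quizzes with each model's model_ae_proba_solutions, using a
# deep copy moved to the main device and switched to eval mode.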
for model in models:
    model = copy.deepcopy(model).to(main_device).eval()
    l.append(model_ae_proba_solutions(model, c_quizzes))