def ae_generate(model, x_0, mask_generate, nb_iterations_max=50, mask_hints=None):
    noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
-    one_iteration_prediction = deterministic(mask_generate)[:, None]
+    single_iteration = deterministic(mask_generate)[:, None]
    if mask_hints is not None:
        mask_generate = mask_generate * (1 - mask_hints)
    hat_x_0 = (1 - mask_generate) * x_0 + mask_generate * dist.sample()
-    hat_x_t_minus_1 = one_iteration_prediction * hat_x_0 + (
-        1 - one_iteration_prediction
+    hat_x_t_minus_1 = single_iteration * hat_x_0 + (
+        1 - single_iteration
    ) * sample_x_t_minus_1_given_x_0_x_t(hat_x_0, x_t)
    if hat_x_t_minus_1.equal(x_t):
-        # log_string(f"exit after {it+1} iterations")
        break
    else:
        changed = changed & (hat_x_t_minus_1 != x_t).max(dim=1).values
    return (-loss).exp()
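

######################################################################


# A minimal sketch of the masked denoising generation that ae_generate
# performs, assuming long token tensors x_0 of shape (N, T), a {0,1}
# mask_generate of the same shape, and a model mapping (N, T) tokens to
# (N, T, nb_colors) logits (torch is assumed imported at module level, as
# elsewhere in this file). The deterministic single-iteration shortcut and
# the partial re-noising step sample_x_t_minus_1_given_x_0_x_t used above
# are simplified here to a plain re-estimation of hat_x_0 until it stops
# changing; masked_denoise_sketch is a hypothetical helper for illustration.
def masked_denoise_sketch(model, x_0, mask_generate, nb_colors, nb_iterations_max=50):
    m = mask_generate.long()

    # Keep the prompt tokens, fill the positions to generate with uniform noise.
    noise = torch.randint(nb_colors, x_0.size(), device=x_0.device)
    x_t = (1 - m) * x_0 + m * noise

    for _ in range(nb_iterations_max):
        logits = model(x_t)
        dist = torch.distributions.categorical.Categorical(logits=logits)

        # Re-estimate the masked tokens, leave the prompt untouched.
        hat_x_0 = (1 - m) * x_0 + m * dist.sample()

        # Stop as soon as an iteration changes nothing.
        if hat_x_0.equal(x_t):
            break

        x_t = hat_x_0

    return x_t
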
-def model_ae_argmax_nb_mistakes(model, input):
-    record = []
-
-    for x_0 in input.split(args.batch_size):
-        nb_mistakes = 0
-        for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
-            mask_generate = quiz_machine.make_quiz_mask(
-                quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
-            )
-
-            logits = logits_hat_x_0_from_random_iteration(
-                model, x_0, mask_generate, prompt_noise=args.prompt_noise
-            )
-
-            predicted = logits.argmax(dim=-1)
-
-            nb_mistakes = nb_mistakes + (
-                mask_generate * predicted != mask_generate * x_0
-            ).long().sum(dim=1)
-
-        record.append(nb_mistakes)
-
-    return torch.cat(record, dim=0)
-
-
-######################################################################
-
-
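# A hypothetical illustration of what the quad masks above select, assuming a
# quiz is the concatenation of four equal-length token blocks laid out in the
# quad_order ("A", "f_A", "B", "f_B"); quad_mask=(0, 0, 0, 1) would then mark
# the f_B block for generation. This is only a sketch of the idea, not the
# actual quiz_machine.make_quiz_mask, whose token layout may differ.
def make_quad_mask_sketch(quizzes, quad_mask):
    n_blocks = len(quad_mask)
    assert quizzes.size(1) % n_blocks == 0
    block = quizzes.size(1) // n_blocks

    mask = torch.zeros_like(quizzes)
    for k, selected in enumerate(quad_mask):
        if selected:
            mask[:, k * block : (k + 1) * block] = 1

    return mask


######################################################################
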
-def model_ae_argmax_predictions(model, input):
-    result = input.clone()
-    # result[...] = 0
-
-    for r, x_0 in zip(result.split(args.batch_size), input.split(args.batch_size)):
-        for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
-            mask_generate = quiz_machine.make_quiz_mask(
-                quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
-            )
-            logits = logits_hat_x_0_from_random_iteration(
-                model, x_0, mask_generate, prompt_noise=args.prompt_noise
-            )
-
-            hat_x_0 = logits.argmax(dim=-1)
-
-            r[...] = (1 - mask_generate) * r + mask_generate * hat_x_0
-
-    return result
-
-
######################################################################
######################################################################
-def save_badness_statistics(
-    n_epoch, models, c_quizzes, suffix=None, local_device=main_device
-):
-    for model in models:
-        model.eval().to(local_device)
-    c_quizzes = c_quizzes.to(local_device)
-    with torch.autograd.no_grad():
-        log_probas = sum(
-            [model_ae_proba_solutions(model, c_quizzes) for model in models]
-        )
-        i = log_probas.sort().indices
-
-    suffix = "" if suffix is None else "_" + suffix
-
-    filename = f"culture_badness_{n_epoch:04d}{suffix}.png"
-
-    quiz_machine.problem.save_quizzes_as_image(
-        args.result_dir,
-        filename,
-        quizzes=c_quizzes[i[:128]],
-        # predicted_parts=predicted_parts,
-        # correct_parts=correct_parts,
-        # comments=comments,
-        delta=True,
-        nrow=8,
-    )
-
-    log_string(f"wrote {filename}")
-
-
-######################################################################
-
-
def quiz_validation(
models,
c_quizzes,