):
c_quiz_bags = [] if c_quizzes is None else [c_quizzes.to("cpu")]
- full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
+ full_input, full_mask_generate, _ = quiz_machine.data_input(
nb,
c_quiz_bags,
data_structures=data_structures,
src = zip(
full_input.split(batch_size),
full_mask_generate.split(batch_size),
- full_mask_loss.split(batch_size),
)
if desc is not None:
total=full_input.size(0) // batch_size,
)
- for input, mask_generate, mask_loss in src:
+ for input, mask_generate in src:
yield (
input.to(local_device),
mask_generate.to(local_device),
- mask_loss.to(local_device),
)
######################################################################
#
-# Given x_0 and t_0, t_1, ..., returns x_{t_0}, x_{t_1}, with
+# Given x_0 and t_0, t_1, ..., returns
#
-# x_{t_k} ~ P(X_{t_k} | X_0=x_0)
+# x_{t_0}, ..., x_{t_K} ~ P(X_{t_0}, ..., X_{t_K} | X_0=x_0)
#
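+# The degradation replaces tokens with uniform random colors: after n
+# iterations a token has been erased with probability
+# 1 - (1 - args.diffusion_noise_proba)^n. The noise values and the
+# per-token thresholds are drawn once, so the samples returned for the
+# different step counts lie on a single degradation trajectory.
+#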
-def degrade_input_to_generate(x0, mask_generate, steps_nb_iterations):
- noise = torch.randint(quiz_machine.problem.nb_colors, x0.size(), device=x0.device)
+def degrade_input_to_generate(x_0, steps_nb_iterations):
+ noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
- r = torch.rand(mask_generate.size(), device=mask_generate.device)
+ r = torch.rand(x_0.size(), device=x_0.device)
result = []
for n in steps_nb_iterations:
proba_erased = 1 - (1 - args.diffusion_noise_proba) ** n
- mask_erased = mask_generate * (r <= proba_erased[:, None]).long()
- x = (1 - mask_erased) * x0 + mask_erased * noise
+ mask_erased = (r <= proba_erased[:, None]).long()
+ x = (1 - mask_erased) * x_0 + mask_erased * noise
result.append(x)
return result
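+
+# For instance degrade_input_to_generate(x_0, (N0, N1)), with N0 and N1
+# two batches of iteration counts such that N0 <= N1, returns two
+# degradations of x_0, the second erasing a superset of the tokens
+# erased by the first.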
######################################################################
-# Given x_t and a mas
-
-def targets_and_logits(model, input, mask_generate, prompt_noise=0.0):
- d = deterministic(mask_generate)
+def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise=0.0):
+ # We favor iterations near the clean signal
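+    # (iteration k in {1, ..., nb_diffusion_iterations} is sampled with
+    # a weight decaying geometrically from 1 down to 0.1)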
probs_iterations = 0.1 ** torch.linspace(
- 0, 1, args.nb_diffusion_iterations, device=input.device
+ 0, 1, args.nb_diffusion_iterations, device=x_0.device
)
probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
- probs_iterations = probs_iterations.expand(input.size(0), -1)
+ probs_iterations = probs_iterations.expand(x_0.size(0), -1)
dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
- # N0 = dist.sample()
- # N1 = N0 + 1
- # N0 = (1 - d) * N0
- # N1 = (1 - d) * N1 + d * args.nb_diffusion_iterations
-
- N0 = input.new_zeros(input.size(0))
N1 = dist.sample() + 1
- targets, input = degrade_input_to_generate(input, mask_generate, (N0, N1))
+ (x_t,) = degrade_input_to_generate(x_0, (N1,))
+
+    # Only the part to generate is degraded, the rest is a perfect,
+    # noise-free conditioning
+
+ x_t = (1 - mask_generate) * x_0 + mask_generate * x_t
+
+    # We may inject noise into the prompt to prevent high-complexity,
+    # non-structured signal from being generated as a way of
+    # "increasing reasoning complexity"
if prompt_noise > 0:
mask_prompt_noise = (
- torch.rand(input.size(), device=input.device) <= prompt_noise
+ torch.rand(x_t.size(), device=x_t.device) <= prompt_noise
).long()
noise = torch.randint(
- quiz_machine.problem.nb_colors, input.size(), device=input.device
+ quiz_machine.problem.nb_colors, x_t.size(), device=x_t.device
)
- noisy_input = (1 - mask_prompt_noise) * input + mask_prompt_noise * noise
- input = (1 - mask_generate) * noisy_input + mask_generate * input
+ noisy_x_t = (1 - mask_prompt_noise) * x_t + mask_prompt_noise * noise
+ x_t = (1 - mask_generate) * noisy_x_t + mask_generate * x_t
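+
+    # The degraded x_t and the generation mask are concatenated
+    # channel-wise and fed to the model, which predicts logits over the
+    # clean x_0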
- input_with_mask = NTC_channel_cat(input, mask_generate)
- logits = model(input_with_mask)
+ x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
+ logits_hat_x_0 = model(x_t_with_mask)
- return targets, logits
+ return logits_hat_x_0
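+
+# At train time (see the training loop below) these logits are scored
+# against the clean x_0 with a cross-entropy restricted to the
+# positions of mask_generate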
######################################################################
return y
-def ae_generate(model, input, mask_generate, nb_iterations_max=50):
- noise = torch.randint(
- quiz_machine.problem.nb_colors, input.size(), device=input.device
- )
+def ae_generate(model, x_0, mask_generate, nb_iterations_max=50):
+ noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
- input = (1 - mask_generate) * input + mask_generate * noise
+ x_t = (1 - mask_generate) * x_0 + mask_generate * noise
- d = deterministic(mask_generate)[:, None]
+ one_iteration_prediction = deterministic(mask_generate)[:, None]
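+
+    # When the mask is flagged as deterministic the model prediction is
+    # taken in a single pass, otherwise the sample is refined over
+    # several de-noising iterations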
changed = True
for it in range(nb_iterations_max):
- input_with_mask = NTC_channel_cat(input, mask_generate)
- logits = model(input_with_mask)
+ x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
+ logits = model(x_t_with_mask)
dist = torch.distributions.categorical.Categorical(logits=logits)
- final = dist.sample()
- r = prioritized_rand(final != input)
+ hat_x_0 = (1 - mask_generate) * x_0 + mask_generate * dist.sample()
- mask_erased = mask_generate * (r <= args.diffusion_noise_proba).long()
+ r = prioritized_rand(hat_x_0 != x_t)
- mask_to_change = d * mask_generate + (1 - d) * mask_erased
+ mask_changes = (r <= args.diffusion_noise_proba).long()
- update = (1 - mask_to_change) * input + mask_to_change * final
+ hat_x_t_minus_1 = one_iteration_prediction * hat_x_0 + (
+ 1 - one_iteration_prediction
+ ) * ((1 - mask_changes) * x_t + mask_changes * hat_x_0)
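+
+        # In the iterative case only a random fraction of the tokens
+        # (rate args.diffusion_noise_proba, prioritized by
+        # prioritized_rand) is switched to the predicted value at each
+        # iteration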
- if update.equal(input):
+ if hat_x_t_minus_1.equal(x_t):
# log_string(f"exit after {it+1} iterations")
break
else:
- changed = changed & (update != input).max(dim=1).values
- input[changed] = update[changed]
-
- # if it == nb_iterations_max:
- # log_string(f"remains {changed.long().sum()}")
+ changed = changed & (hat_x_t_minus_1 != x_t).max(dim=1).values
+ x_t[changed] = hat_x_t_minus_1[changed]
- return input
+ return x_t
######################################################################
def model_ae_proba_solutions(model, input, log_proba=False):
record = []
- for q in input.split(args.batch_size):
+ for x_0 in input.split(args.batch_size):
loss = 0
for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
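+            # Mask in turn each of the four parts A, f_A, B, f_B and
+            # accumulate the cross-entropy of predicting it from the
+            # other three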
mask_generate = quiz_machine.make_quiz_mask(
-                quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
+                quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
)
- targets, logits = targets_and_logits(
- model, q, mask_generate, prompt_noise=args.prompt_noise
+ logits = logits_hat_x_0_from_random_iteration(
+ model, x_0, mask_generate, prompt_noise=args.prompt_noise
)
loss_per_token = F.cross_entropy(
- logits.transpose(1, 2), targets, reduction="none"
+ logits.transpose(1, 2), x_0, reduction="none"
)
loss += (loss_per_token * mask_generate).sum(dim=1)
record.append(loss)
def model_ae_argmax_nb_disagreements(model, input):
record = []
- for q in input.split(args.batch_size):
+ for x_0 in input.split(args.batch_size):
nb_disagreements = 0
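+        # Count the masked tokens for which the argmax prediction
+        # differs from the reference x_0, accumulated over the four
+        # parts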
for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
mask_generate = quiz_machine.make_quiz_mask(
-                quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
+                quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
)
- targets, logits = targets_and_logits(
- model, q, mask_generate, prompt_noise=args.prompt_noise
+ logits = logits_hat_x_0_from_random_iteration(
+ model, x_0, mask_generate, prompt_noise=args.prompt_noise
)
predicted = logits.argmax(dim=-1)
nb_disagreements = nb_disagreements + (
- mask_generate * predicted != mask_generate * targets
+ mask_generate * predicted != mask_generate * x_0
).long().sum(dim=1)
record.append(nb_disagreements)
result = input.clone()
# result[...] = 0
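+    # Replace each of the four parts in turn by the model's argmax
+    # prediction conditioned on the other three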
- for r, q in zip(result.split(args.batch_size), input.split(args.batch_size)):
+ for r, x_0 in zip(result.split(args.batch_size), input.split(args.batch_size)):
for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
mask_generate = quiz_machine.make_quiz_mask(
quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
)
- targets, logits = targets_and_logits(
- model, q, mask_generate, prompt_noise=args.prompt_noise
+ logits = logits_hat_x_0_from_random_iteration(
+ model, x_0, mask_generate, prompt_noise=args.prompt_noise
)
- predicted = logits.argmax(dim=-1)
+ hat_x_0 = logits.argmax(dim=-1)
- r[...] = (1 - mask_generate) * r + mask_generate * predicted
+ r[...] = (1 - mask_generate) * r + mask_generate * hat_x_0
return result
nb_test_samples, acc_test_loss = 0, 0.0
- for input, mask_generate, mask_loss in ae_batches(
+ for x_0, mask_generate in ae_batches(
quiz_machine,
args.nb_test_samples,
data_structures,
c_quizzes=c_quizzes,
desc="test",
):
- targets, logits = targets_and_logits(model, input, mask_generate)
- loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
- acc_test_loss += loss.item() * input.size(0)
- nb_test_samples += input.size(0)
+ logits = logits_hat_x_0_from_random_iteration(model, x_0, mask_generate)
+ loss = NTC_masked_cross_entropy(logits, x_0, mask_generate)
+ acc_test_loss += loss.item() * x_0.size(0)
+ nb_test_samples += x_0.size(0)
log_string(
f"{prefix}test_loss {n_epoch} model {model.id} {acc_test_loss/nb_test_samples}"
nb_correct, nb_total, record_d, record_nd = 0, 0, [], []
- for input, mask_generate, mask_loss in ae_batches(
+ for x_0, mask_generate in ae_batches(
quiz_machine,
args.nb_test_samples,
data_structures,
c_quizzes=c_quizzes,
desc="test",
):
- targets = input.clone()
result = ae_generate(
model,
- (1 - mask_generate) * input,
+ (1 - mask_generate) * x_0,
mask_generate,
)
- correct = (result == targets).min(dim=1).values.long()
+ correct = (result == x_0).min(dim=1).values.long()
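+
+        # One 0/1 flag per part, read from a single token of each part,
+        # indicating whether that part was generated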
predicted_parts = mask_generate.reshape(mask_generate.size(0), 4, -1)[
:, :, 1
]
result, predicted_parts, correct_parts = bag_to_tensors(record)
- # l = [model_ae_proba_solutions(model, result) for model in other_models]
- # probas = torch.cat([x[:, None] for x in l], dim=1)
- # comments = []
-
- # for l in probas:
- # comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
-
quiz_machine.problem.save_quizzes_as_image(
args.result_dir,
filename,
quizzes=result[:128],
predicted_parts=predicted_parts[:128],
correct_parts=correct_parts[:128],
- # comments=comments,
)
log_string(f"wrote {filename}")
- # Prediction with functional perturbations
-
- # input, mask_generate, mask_loss = next(
- # ae_batches(
- # quiz_machine,
- # [
- # (
- # ("A", "f_A", "B", "f_B"),
- # (0, 0, 0, 1),
- # (0, 0, 1, 0),
- # (0, 0, 0, 1),
- # ),
- # ],
- # local_device,
- # desc=None,
- # )
- # )
- # targets = input.clone()
- # p = torch.rand(4,model.f_tokens.size(1)).sort(dim=1).indices
- # def change_theta(theta_A, theta_B):
- # theta
- # result = ae_generate(
- # model, (1 - mask_generate) * input, mask_generate
- # )
-
######################################################################
nb_train_samples, acc_train_loss = 0, 0.0
- for input, mask_generate, mask_loss in ae_batches(
+ for x_0, mask_generate in ae_batches(
quiz_machine,
args.nb_train_samples,
data_structures,
c_quizzes,
"training",
):
- input = input.to(local_device)
+ x_0 = x_0.to(local_device)
mask_generate = mask_generate.to(local_device)
- mask_loss = mask_loss.to(local_device)
if nb_train_samples % args.batch_size == 0:
model.optimizer.zero_grad()
- targets, logits = targets_and_logits(
- model, input, mask_generate, prompt_noise=args.prompt_noise
+ logits = logits_hat_x_0_from_random_iteration(
+ model, x_0, mask_generate, prompt_noise=args.prompt_noise
)
- loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
- acc_train_loss += loss.item() * input.size(0)
- nb_train_samples += input.size(0)
+ loss = NTC_masked_cross_entropy(logits, x_0, mask_generate)
+ acc_train_loss += loss.item() * x_0.size(0)
+ nb_train_samples += x_0.size(0)
loss.backward()