parser.add_argument("--nb_models", type=int, default=5)
+parser.add_argument("--nb_diffusion_iterations", type=int, default=25)
+
parser.add_argument("--min_succeed_to_validate", type=int, default=2)
parser.add_argument("--max_fail_to_validate", type=int, default=3)
######################################################################
+
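+# A "bag" is a list of tuples of tensors accumulated batch by batch; bag_len
+# counts the total number of rows and bag_to_tensors concatenates each tuple
+# component along the batch dimension.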
+def bag_len(bag):
+ return sum([x[0].size(0) for x in bag])
+
+
+def bag_to_tensors(bag):
+ return tuple(torch.cat([x[i] for x in bag], dim=0) for i in range(len(bag[0])))
+
+
+######################################################################
+
# If we need to move an optimizer to a different device
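+# A minimal sketch of such a helper, assuming the usual torch.optim state
+# layout (illustrative only; names are not taken from this code base):
+#
+# def optimizer_to(optim, device):
+#     for state in optim.state.values():
+#         for k, v in state.items():
+#             if torch.is_tensor(v):
+#                 state[k] = v.to(device)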
######################################################################
-nb_iterations = 25
-probs_iterations = 0.1 ** torch.linspace(0, 1, nb_iterations, device=main_device)
-probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
-
def ae_batches(
quiz_machine,
changed = True
for it in range(nb_iterations_max):
+ print(f"{it=} {nb_iterations_max=}")
+
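+        # One denoising step: append the generation mask as an extra channel,
+        # predict token logits and sample a candidate update for the masked
+        # positions; rows that stop changing are frozen below.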
input_with_mask = NTC_channel_cat(input, mask_generate)
logits = model(input_with_mask)
dist = torch.distributions.categorical.Categorical(logits=logits)
changed = changed & (update != input).max(dim=1).values
input[changed] = update[changed]
- log_string(f"remains {changed.long().sum()}")
+        if it == nb_iterations_max - 1:
+ log_string(f"remains {changed.long().sum()}")
return input
mask_generate = quiz_machine.make_quiz_mask(
quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
)
- targets, logits = targets_and_prediction(
- probs_iterations, model, q, mask_generate
- )
+ targets, logits = targets_and_prediction(model, q, mask_generate)
loss_per_token = F.cross_entropy(
logits.transpose(1, 2), targets, reduction="none"
)
return (-loss).exp()
+nb_diffusion_iterations = 25
+
+
def degrade_input(input, mask_generate, nb_iterations, noise_proba):
noise = torch.randint(
quiz_machine.problem.nb_colors, input.size(), device=input.device
return result
-def targets_and_prediction(probs_iterations, model, input, mask_generate):
+def targets_and_prediction(model, input, mask_generate):
d = deterministic(mask_generate)
- p = probs_iterations.expand(input.size(0), -1)
- dist = torch.distributions.categorical.Categorical(probs=p)
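+    # Geometric schedule over the number of degradation steps: weights decay
+    # from 1 down to 0.1 over args.nb_diffusion_iterations values and are then
+    # normalized into a categorical distribution from which N0 is sampled.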
+ probs_iterations = 0.1 ** torch.linspace(
+ 0, 1, args.nb_diffusion_iterations, device=input.device
+ )
+ probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
+ probs_iterations = probs_iterations.expand(input.size(0), -1)
+ dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
N0 = dist.sample()
N1 = N0 + 1
N0 = (1 - d) * N0
- N1 = (1 - d) * N1 + d * nb_iterations
+ N1 = (1 - d) * N1 + d * args.nb_diffusion_iterations
targets, input = degrade_input(
input, mask_generate, (0 * N1, N1), noise_proba=noise_proba
return targets, logits
-def run_ae_test(model, other_models, quiz_machine, n_epoch, local_device=main_device):
+def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device):
with torch.autograd.no_grad():
model.eval().to(local_device)
local_device,
"test",
):
- targets, logits = targets_and_prediction(
- probs_iterations, model, input, mask_generate
- )
+ targets, logits = targets_and_prediction(model, input, mask_generate)
loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
acc_test_loss += loss.item() * input.size(0)
nb_test_samples += input.size(0)
model.test_accuracy = nb_correct / nb_total
- for f, record in [("prediction", record_d), ("generation", record_nd)]:
- filename = f"culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
- result, predicted_parts, correct_parts = (
- torch.cat([x[i] for x in record])[:128] for i in [0, 1, 2]
- )
+ # for f, record in [("prediction", record_d), ("generation", record_nd)]:
+ # filename = f"culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
- l = [model_ae_proba_solutions(model, result) for model in other_models]
- probas = torch.cat([x[:, None] for x in l], dim=1)
- comments = []
+ # result, predicted_parts, correct_parts = bag_to_tensors(record)
- for l in probas:
- comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
+ # l = [model_ae_proba_solutions(model, result) for model in other_models]
+ # probas = torch.cat([x[:, None] for x in l], dim=1)
+ # comments = []
- quiz_machine.problem.save_quizzes_as_image(
- args.result_dir,
- filename,
- quizzes=result,
- predicted_parts=predicted_parts,
- correct_parts=correct_parts,
- comments=comments,
- )
- log_string(f"wrote {filename}")
+ # for l in probas:
+ # comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
+
+ # quiz_machine.problem.save_quizzes_as_image(
+ # args.result_dir,
+ # filename,
+ # quizzes=result,
+ # predicted_parts=predicted_parts,
+ # correct_parts=correct_parts,
+ # comments=comments,
+ # )
+ # log_string(f"wrote {filename}")
# Prediction with functional perturbations
local_device,
"training",
):
+ input = input.to(local_device)
+ mask_generate = mask_generate.to(local_device)
+ mask_loss = mask_loss.to(local_device)
+
if nb_train_samples % args.batch_size == 0:
model.optimizer.zero_grad()
- targets, logits = targets_and_prediction(
- probs_iterations, model, input, mask_generate
+ targets, logits = targets_and_prediction(model, input, mask_generate)
+
+ print(
+            f"{input.device=} {logits.device=} {targets.device=} {mask_loss.device=}"
)
+
loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
acc_train_loss += loss.item() * input.size(0)
nb_train_samples += input.size(0)
f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"
)
- run_ae_test(model, other_models, quiz_machine, n_epoch, local_device=local_device)
+ # run_ae_test(model, other_models, quiz_machine, n_epoch, local_device=local_device)
######################################################################
######################################################################
+
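+# Criteria for keeping a generated c_quiz, based on the vector of per-model
+# probabilities that its solution is correct: one confident model and one that
+# rejects it, a large gap between the best and worst model, or at least two
+# near-certain models while another remains at 0.5 or below.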
+def c_quiz_criterion_one_good_one_bad(probas):
+ return (probas.max(dim=1).values >= 0.8) & (probas.min(dim=1).values <= 0.2)
+
+
+def c_quiz_criterion_diff(probas):
+ return (probas.max(dim=1).values - probas.min(dim=1).values) >= 0.5
+
+
+def c_quiz_criterion_two_certains(probas):
+ return ((probas >= 0.99).long().sum(dim=1) >= 2) & (probas.min(dim=1).values <= 0.5)
+
+
+def generate_ae_c_quizzes(models, local_device=main_device):
+ criteria = [
+ c_quiz_criterion_one_good_one_bad,
+ c_quiz_criterion_diff,
+ c_quiz_criterion_two_certains,
+ ]
+
+ for m in models:
+ m.eval().to(local_device)
+
+ quad_order = ("A", "f_A", "B", "f_B")
+
+ template = quiz_machine.problem.create_empty_quizzes(
+ nb=args.batch_size, quad_order=quad_order
+ ).to(local_device)
+
+ mask_generate = quiz_machine.make_quiz_mask(
+ quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1)
+ )
+
+ records = [[] for _ in criteria]
+
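+    # Generate until every criterion has collected at least 128 quizzes: pick a
+    # model at random, let it generate complete quizzes (all four quadrants
+    # masked), score the results with every model, and append each quiz to the
+    # bags whose criterion it satisfies.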
+ with torch.autograd.no_grad():
+ while min([bag_len(bag) for bag in records]) < 128:
+ bl = [bag_len(bag) for bag in records]
+ log_string(f"bag_len {bl}")
+
+ model = models[torch.randint(len(models), (1,)).item()]
+ result = ae_generate(model, template, mask_generate, 0.0)
+
+ probas = torch.cat(
+ [model_ae_proba_solutions(model, result)[:, None] for model in models],
+ dim=1,
+ )
+ for c, r in zip(criteria, records):
+ q = result[c(probas)]
+ if q.size(0) > 0:
+ r.append(q)
+
+ # for f, record in [("prediction", record_d), ("generation", record_nd)]:
+ # filename = f"culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
+
+ # result, predicted_parts, correct_parts = bag_to_tensors(record)
+
+ # l = [model_ae_proba_solutions(model, result) for model in models]
+ # probas = torch.cat([x[:, None] for x in l], dim=1)
+ # comments = []
+
+ # for l in probas:
+ # comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
+
+ # quiz_machine.problem.save_quizzes_as_image(
+ # args.result_dir,
+ # filename,
+ # quizzes=result,
+ # predicted_parts=predicted_parts,
+ # correct_parts=correct_parts,
+ # comments=comments,
+ # )
+ # log_string(f"wrote {filename}")
+
+
+######################################################################
+
current_epoch = 0
if args.resume:
for t in threads:
t.join()
+ generate_ae_c_quizzes(models, local_device=main_device)
+
# --------------------------------------------------------------------
for model in models: