corresponding value from the target into the input
"""
- input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2]
+ input, masks, targets = imt_set.unbind(dim=1)
# h = torch.rand(masks.size(), device=masks.device) - masks
# t = h.sort(dim=1).values[:, args.nb_hints, None]
# mask_hints = (h < t).long()
def add_noise_imt(imt_set):
"""Replace every component of the input by a random value with
probability args.proba_prompt_noise."""
- input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2]
+ input, masks, targets = imt_set.unbind(dim=1)
noise = problem.pure_noise(input.size(0), input.device)
change = (1 - masks) * (
torch.rand(input.size(), device=input.device) < args.proba_prompt_noise
######################################################################
-def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
+def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device):
quizzes = generate_quiz_set(
args.nb_train_samples if train else args.nb_test_samples,
c_quizzes,
desc=label,
total=quizzes.size(0) // batch_size,
):
- input, masks, targets = imt[:, 0], imt[:, 1], imt[:, 2]
+ input, masks, targets = imt.unbind(dim=1)
if train and nb_samples % args.batch_size == 0:
model.optimizer.zero_grad()
######################################################################
-def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
- # train
-
- one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=True)
- one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=False)
-
- # Save some original world quizzes and the full prediction (the four grids)
-
- quizzes = generate_quiz_set(25, c_quizzes, args.c_quiz_multiplier).to(local_device)
- problem.save_quizzes_as_image(
- args.result_dir, f"test_{n_epoch}_{model.id}.png", quizzes=quizzes
- )
- result = predict_full(
- model=model,
- input=quizzes,
- with_noise=True,
- with_hints=True,
- local_device=local_device,
- )
- problem.save_quizzes_as_image(
- args.result_dir, f"test_{n_epoch}_{model.id}_predict_full.png", quizzes=result
- )
-
+def save_inference_images(model, n_epoch, c_quizzes, c_quiz_multiplier, local_device):
# Save some images of the prediction results
- quizzes = generate_quiz_set(args.nb_test_samples, c_quizzes, args.c_quiz_multiplier)
+ quizzes = generate_quiz_set(150, c_quizzes, args.c_quiz_multiplier)
imt_set = samples_for_prediction_imt(quizzes.to(local_device))
result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
masks = imt_set[:, 1].to("cpu")
correct_parts=correct_parts[:128],
)
+ # Save some images of the ex nihilo generation of the four grids
+
+ result = ae_generate(model, 150, local_device=local_device).to("cpu")
+ problem.save_quizzes_as_image(
+ args.result_dir,
+ f"culture_generation_{n_epoch}_{model.id}.png",
+ quizzes=result[:128],
+ )
+
+
+######################################################################
+
+
+def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
+ one_epoch(model, n_epoch, c_quizzes, train=True, local_device=local_device)
+
+ one_epoch(model, n_epoch, c_quizzes, train=False, local_device=local_device)
+
# Compute the test accuracy
+ quizzes = generate_quiz_set(args.nb_test_samples, c_quizzes, args.c_quiz_multiplier)
+ imt_set = samples_for_prediction_imt(quizzes.to(local_device))
+ result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
+ correct = (quizzes == result).min(dim=1).values.long()
+
nb_correct, nb_total = correct.sum().item(), quizzes.size(0)
model.test_accuracy = nb_correct / nb_total
f"test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({model.test_accuracy*100:.02f}%)"
)
- # Save some images of the ex nihilo generation of the four grids
-
- result = ae_generate(model, 150, local_device=local_device).to("cpu")
- problem.save_quizzes_as_image(
- args.result_dir,
- f"culture_generation_{n_epoch}_{model.id}.png",
- quizzes=result[:128],
+ save_inference_images(
+ model, n_epoch, c_quizzes, args.c_quiz_multiplier, local_device=local_device
)