From 83b9f0642e6d592ae51c4989ce7be8ef2a2f7469 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Fri, 20 Sep 2024 21:10:43 +0200
Subject: [PATCH] Update.

---
 main.py | 66 +++++++++++++++++++++++++++------------------------------
 1 file changed, 31 insertions(+), 35 deletions(-)

diff --git a/main.py b/main.py
index 10e6bc0..6c20d2f 100755
--- a/main.py
+++ b/main.py
@@ -315,7 +315,7 @@ def add_hints_imt(imt_set):
     corresponding value from the target into the input
 
     """
-    input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2]
+    input, masks, targets = imt_set.unbind(dim=1)
     # h = torch.rand(masks.size(), device=masks.device) - masks
     # t = h.sort(dim=1).values[:, args.nb_hints, None]
     # mask_hints = (h < t).long()
@@ -330,7 +330,7 @@ def add_noise_imt(imt_set):
     """Replace every component of the input by a random value with
     probability args.proba_prompt_noise."""
 
-    input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2]
+    input, masks, targets = imt_set.unbind(dim=1)
     noise = problem.pure_noise(input.size(0), input.device)
     change = (1 - masks) * (
         torch.rand(input.size(), device=input.device) < args.proba_prompt_noise
@@ -489,7 +489,7 @@ def ae_generate(model, nb, local_device=main_device):
 ######################################################################
 
 
-def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
+def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device):
     quizzes = generate_quiz_set(
         args.nb_train_samples if train else args.nb_test_samples,
         c_quizzes,
@@ -529,7 +529,7 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
         desc=label,
         total=quizzes.size(0) // batch_size,
     ):
-        input, masks, targets = imt[:, 0], imt[:, 1], imt[:, 2]
+        input, masks, targets = imt.unbind(dim=1)
 
         if train and nb_samples % args.batch_size == 0:
             model.optimizer.zero_grad()
@@ -555,32 +555,10 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
 ######################################################################
 
 
-def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
-    # train
-
-    one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=True)
-    one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=False)
-
-    # Save some original world quizzes and the full prediction (the four grids)
-
-    quizzes = generate_quiz_set(25, c_quizzes, args.c_quiz_multiplier).to(local_device)
-    problem.save_quizzes_as_image(
-        args.result_dir, f"test_{n_epoch}_{model.id}.png", quizzes=quizzes
-    )
-    result = predict_full(
-        model=model,
-        input=quizzes,
-        with_noise=True,
-        with_hints=True,
-        local_device=local_device,
-    )
-    problem.save_quizzes_as_image(
-        args.result_dir, f"test_{n_epoch}_{model.id}_predict_full.png", quizzes=result
-    )
-
+def save_inference_images(model, n_epoch, c_quizzes, c_quiz_multiplier, local_device):
     # Save some images of the prediction results
 
-    quizzes = generate_quiz_set(args.nb_test_samples, c_quizzes, args.c_quiz_multiplier)
+    quizzes = generate_quiz_set(150, c_quizzes, args.c_quiz_multiplier)
     imt_set = samples_for_prediction_imt(quizzes.to(local_device))
     result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
     masks = imt_set[:, 1].to("cpu")
@@ -599,8 +577,31 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
         correct_parts=correct_parts[:128],
     )
 
+    # Save some images of the ex nihilo generation of the four grids
+
+    result = ae_generate(model, 150, local_device=local_device).to("cpu")
+    problem.save_quizzes_as_image(
+        args.result_dir,
+        f"culture_generation_{n_epoch}_{model.id}.png",
+        quizzes=result[:128],
+    )
+
+
+######################################################################
+
+
+def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
+    one_epoch(model, n_epoch, c_quizzes, train=True, local_device=local_device)
+
+    one_epoch(model, n_epoch, c_quizzes, train=False, local_device=local_device)
+    # Compute the test accuracy
+    quizzes = generate_quiz_set(args.nb_test_samples, c_quizzes, args.c_quiz_multiplier)
+    imt_set = samples_for_prediction_imt(quizzes.to(local_device))
+    result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
+    correct = (quizzes == result).min(dim=1).values.long()
+    nb_correct, nb_total = correct.sum().item(), quizzes.size(0)
 
     model.test_accuracy = nb_correct / nb_total
 
@@ -608,13 +609,8 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
         f"test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({model.test_accuracy*100:.02f}%)"
     )
 
-    # Save some images of the ex nihilo generation of the four grids
-
-    result = ae_generate(model, 150, local_device=local_device).to("cpu")
-    problem.save_quizzes_as_image(
-        args.result_dir,
-        f"culture_generation_{n_epoch}_{model.id}.png",
-        quizzes=result[:128],
+    save_inference_images(
+        model, n_epoch, c_quizzes, args.c_quiz_multiplier, local_device=local_device
     )
-- 
2.39.5
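
Illustrative note, not part of the patch: the change repeated in the first hunks swaps explicit slicing of the stacked input/masks/targets tensor for torch.Tensor.unbind. A minimal sketch of that idiom follows, assuming the imt tensors are packed along dim 1 as (N, 3, L); the toy shape and the variable names a, b, c and x, y, z are placeholders, not names from main.py.

    import torch

    # Toy stand-in for an imt set: (input, masks, targets) stacked along dim 1.
    imt_set = torch.randint(0, 10, (5, 3, 12))

    # Explicit slicing, as in the old code: three separate views of shape (5, 12).
    a, b, c = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2]

    # unbind(dim=1) removes dim 1 and returns one view per slice in a single call.
    x, y, z = imt_set.unbind(dim=1)

    assert torch.equal(a, x) and torch.equal(b, y) and torch.equal(c, z)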