Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 20 Sep 2024 19:10:43 +0000 (21:10 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 20 Sep 2024 19:10:43 +0000 (21:10 +0200)
main.py

diff --git a/main.py b/main.py
index 10e6bc0..6c20d2f 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -315,7 +315,7 @@ def add_hints_imt(imt_set):
     corresponding value from the target into the input
 
     """
-    input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2]
+    input, masks, targets = imt_set.unbind(dim=1)
     # h = torch.rand(masks.size(), device=masks.device) - masks
     # t = h.sort(dim=1).values[:, args.nb_hints, None]
     # mask_hints = (h < t).long()
@@ -330,7 +330,7 @@ def add_hints_imt(imt_set):
 def add_noise_imt(imt_set):
     """Replace every component of the input by a random value with
     probability args.proba_prompt_noise."""
-    input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2]
+    input, masks, targets = imt_set.unbind(dim=1)
     noise = problem.pure_noise(input.size(0), input.device)
     change = (1 - masks) * (
         torch.rand(input.size(), device=input.device) < args.proba_prompt_noise
@@ -489,7 +489,7 @@ def ae_generate(model, nb, local_device=main_device):
 ######################################################################
 
 
-def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
+def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device):
     quizzes = generate_quiz_set(
         args.nb_train_samples if train else args.nb_test_samples,
         c_quizzes,
@@ -529,7 +529,7 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
         desc=label,
         total=quizzes.size(0) // batch_size,
     ):
-        input, masks, targets = imt[:, 0], imt[:, 1], imt[:, 2]
+        input, masks, targets = imt.unbind(dim=1)
         if train and nb_samples % args.batch_size == 0:
             model.optimizer.zero_grad()
 
@@ -555,32 +555,10 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
 ######################################################################
 
 
-def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
-    # train
-
-    one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=True)
-    one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=False)
-
-    # Save some original world quizzes and the full prediction (the four grids)
-
-    quizzes = generate_quiz_set(25, c_quizzes, args.c_quiz_multiplier).to(local_device)
-    problem.save_quizzes_as_image(
-        args.result_dir, f"test_{n_epoch}_{model.id}.png", quizzes=quizzes
-    )
-    result = predict_full(
-        model=model,
-        input=quizzes,
-        with_noise=True,
-        with_hints=True,
-        local_device=local_device,
-    )
-    problem.save_quizzes_as_image(
-        args.result_dir, f"test_{n_epoch}_{model.id}_predict_full.png", quizzes=result
-    )
-
+def save_inference_images(model, n_epoch, c_quizzes, c_quiz_multiplier, local_device):
     # Save some images of the prediction results
 
-    quizzes = generate_quiz_set(args.nb_test_samples, c_quizzes, args.c_quiz_multiplier)
+    quizzes = generate_quiz_set(150, c_quizzes, args.c_quiz_multiplier)
     imt_set = samples_for_prediction_imt(quizzes.to(local_device))
     result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
     masks = imt_set[:, 1].to("cpu")
@@ -599,8 +577,31 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
         correct_parts=correct_parts[:128],
     )
 
+    # Save some images of the ex nihilo generation of the four grids
+
+    result = ae_generate(model, 150, local_device=local_device).to("cpu")
+    problem.save_quizzes_as_image(
+        args.result_dir,
+        f"culture_generation_{n_epoch}_{model.id}.png",
+        quizzes=result[:128],
+    )
+
+
+######################################################################
+
+
+def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
+    one_epoch(model, n_epoch, c_quizzes, train=True, local_device=local_device)
+
+    one_epoch(model, n_epoch, c_quizzes, train=False, local_device=local_device)
+
     # Compute the test accuracy
 
+    quizzes = generate_quiz_set(args.nb_test_samples, c_quizzes, args.c_quiz_multiplier)
+    imt_set = samples_for_prediction_imt(quizzes.to(local_device))
+    result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
+    correct = (quizzes == result).min(dim=1).values.long()
+
     nb_correct, nb_total = correct.sum().item(), quizzes.size(0)
     model.test_accuracy = nb_correct / nb_total
 
@@ -608,13 +609,8 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
         f"test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({model.test_accuracy*100:.02f}%)"
     )
 
-    # Save some images of the ex nihilo generation of the four grids
-
-    result = ae_generate(model, 150, local_device=local_device).to("cpu")
-    problem.save_quizzes_as_image(
-        args.result_dir,
-        f"culture_generation_{n_epoch}_{model.id}.png",
-        quizzes=result[:128],
+    save_inference_images(
+        model, n_epoch, c_quizzes, args.c_quiz_multiplier, local_device=local_device
     )