Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 16 Sep 2024 20:49:37 +0000 (22:49 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 16 Sep 2024 20:49:37 +0000 (22:49 +0200)
main.py

diff --git a/main.py b/main.py
index 649c889..70ca672 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -658,11 +658,9 @@ def batch_prediction(input, proba_hints=0.0):
     return input, targets, mask_generate
 
 
-def predict(model, quizzes, local_device=main_device):
+def predict(model, input, targets, mask, local_device=main_device):
     model.eval().to(local_device)
 
-    input, targets, mask = batch_prediction(quizzes.to(local_device))
-
     input_batches = input.reshape(-1, args.physical_batch_size, input.size(1))
     targets_batches = targets.reshape(-1, args.physical_batch_size, targets.size(1))
     mask_batches = mask.reshape(-1, args.physical_batch_size, mask.size(1))
@@ -673,7 +671,7 @@ def predict(model, quizzes, local_device=main_device):
         zip(input_batches, targets_batches, mask_batches),
         dynamic_ncols=True,
         desc="predict",
-        total=quizzes.size(0) // args.physical_batch_size,
+        total=input.size(0) // args.physical_batch_size,
     ):
         # noise = quiz_machine.problem.pure_noise(input.size(0), input.device)
         input = (1 - mask) * input  # + mask * noise
@@ -806,20 +804,24 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
 
 def one_train_test_epoch(model, n_epoch, c_quizzes, local_device=main_device):
     one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True)
-
     one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=False)
 
     quizzes = quiz_machine.quiz_set(150, c_quizzes, args.c_quiz_multiplier)
-    result = predict(model, quizzes).to("cpu")
-
+    input, targets, mask = batch_prediction(quizzes.to(local_device))
+    result = predict(model, input, targets, mask).to("cpu")
+    mask = mask.to("cpu")
+    correct = (quizzes == result).min(dim=1).values.long()
+    correct_parts = (2 * correct - 1)[:, None] * mask.reshape(mask.size(0), 4, -1)[
+        :, :, 1
+    ]
     quiz_machine.problem.save_quizzes_as_image(
         args.result_dir,
         f"culture_prediction_{n_epoch}_{model.id}.png",
         quizzes=result[:128],
+        correct_parts=correct_parts[:128],
     )
 
-    nb_correct = (quizzes == result).min(dim=1).values.long().sum()
-    model.test_accuracy = nb_correct / quizzes.size(0)
+    model.test_accuracy = correct.sum() / quizzes.size(0)
 
 
 ######################################################################