Update.
author François Fleuret <francois@fleuret.org>
Thu, 22 Aug 2024 19:04:10 +0000 (21:04 +0200)
committer François Fleuret <francois@fleuret.org>
Thu, 22 Aug 2024 19:04:10 +0000 (21:04 +0200)
main.py
mygpt.py

diff --git a/main.py b/main.py
index 36b369e..b6c62cf 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -847,25 +847,24 @@ def test_ae(local_device=main_device):
         model.train()
         nb_train_samples, acc_train_loss = 0, 0.0
 
-        data_structures = [
-            (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
-        ]
-
         full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
-            args.nb_train_samples, data_structures=data_structures
+            args.nb_train_samples
         )
 
         src = zip(
-            full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
+            full_input.split(args.batch_size),
+            full_mask_generate.split(args.batch_size),
+            full_mask_loss.split(args.batch_size),
         )
 
-        for input, mask_loss in tqdm.tqdm(
+        for input, mask_generate, mask_loss in tqdm.tqdm(
             src,
             dynamic_ncols=True,
             desc="training",
             total=full_input.size(0) // args.batch_size,
         ):
             input = input.to(local_device)
+            mask_generate = mask_generate.to(local_device)
             mask_loss = mask_loss.to(local_device)
 
             if nb_train_samples % args.batch_size == 0:
@@ -896,21 +895,25 @@ def test_ae(local_device=main_device):
             nb_test_samples, acc_test_loss = 0, 0.0
 
             full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
-                args.nb_test_samples, data_structures=data_structures
+                args.nb_test_samples
             )
 
             src = zip(
-                full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
+                full_input.split(args.batch_size),
+                full_mask_generate.split(args.batch_size),
+                full_mask_loss.split(args.batch_size),
             )
 
-            for input, mask_loss in tqdm.tqdm(
+            for input, mask_generate, mask_loss in tqdm.tqdm(
                 src,
                 dynamic_ncols=True,
                 desc="testing",
                 total=full_input.size(0) // args.batch_size,
             ):
                 input = input.to(local_device)
+                mask_generate = mask_generate.to(local_device)
                 mask_loss = mask_loss.to(local_device)
+
                 targets = input
                 input = (mask_generate == 0).long() * input
                 output = model(mygpt.BracketedSequence(input)).x
@@ -920,10 +923,9 @@ def test_ae(local_device=main_device):
 
             log_string(f"test_loss {n_epoch} model AE {acc_test_loss/nb_test_samples}")
 
-            input, mask_generate, mask_loss = quiz_machine.data_input(
-                128, data_structures=data_structures
-            )
+            input, mask_generate, mask_loss = quiz_machine.data_input(128)
             input = input.to(local_device)
+            mask_generate = mask_generate.to(local_device)
             mask_loss = mask_loss.to(local_device)
             targets = input
             input = (mask_generate == 0).long() * input
@@ -935,12 +937,41 @@ def test_ae(local_device=main_device):
             result[:, 1 * L] = input[:, 1 * L]
             result[:, 2 * L] = input[:, 2 * L]
             result[:, 3 * L] = input[:, 3 * L]
+            correct = (result == targets).min(dim=1).values.long()
+            predicted_parts = input.new(input.size(0), 4)
+
+            nb = 0
+
+            # We consider all the configurations that we train for
+            for struct, quad_generate, _, _ in quiz_machine.test_structures:
+                i = quiz_machine.problem.indices_select(quizzes=input, struct=struct)
+                nb += i.long().sum()
+
+                predicted_parts[i] = torch.tensor(quad_generate, device=result.device)[
+                    None, :
+                ]
+                solution_is_deterministic = predicted_parts[i].sum(dim=-1) == 1
+                correct[i] = (2 * correct[i] - 1) * (solution_is_deterministic).long()
+
+            assert nb == input.size(0)
+
+            nb_correct = (correct == 1).long().sum()
+            nb_total = (correct != 0).long().sum()
+
+            self.logger(
+                f"test_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
+            )
+
+            correct_parts = predicted_parts * correct[:, None]
+
             filename = f"prediction_ae_{n_epoch:04d}.png"
 
             quiz_machine.problem.save_quizzes_as_image(
                 args.result_dir,
                 filename,
                 quizzes=result,
+                predicted_parts=predicted_parts,
+                correct_parts=correct_parts,
             )
 
             log_string(f"wrote {filename}")
diff --git a/mygpt.py b/mygpt.py
index 8379a57..a744224 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -164,7 +164,7 @@ class TrainablePositionalEncoding(nn.Module):
             self.cache_y = bs.x.new(bs.x.size())
 
         self.cache_y[:, bs.first : bs.first + bs.nb] = (
-            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
+            bs.slice() + self.pe[:, bs.first : bs.first + bs.nb, :]
         )
 
         return BracketedSequence(self.cache_y, bs.first, bs.nb)