Rename generate_c_quizz_with_generator to generate_c_quizzes_with_generator, return the prologs with the generated c_quizzes, move the generator checkpoint loading after its initialization, and log the prolog prediction error in test_generator.
author     François Fleuret <francois@fleuret.org>
           Wed, 31 Jul 2024 19:02:38 +0000 (21:02 +0200)
committer  François Fleuret <francois@fleuret.org>
           Wed, 31 Jul 2024 19:02:38 +0000 (21:02 +0200)
main.py

diff --git a/main.py b/main.py
index 4903585..cce747a 100755
--- a/main.py
+++ b/main.py
@@ -682,7 +682,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 # 2->proba>=proba_understands and 1 otherwise.
 
 
-def generate_c_quizz_with_generator(generator, quiz_machine, nb):
+def generate_c_quizzes_with_generator(generator, quiz_machine, nb):
     generator.to(main_device)
 
     struct = ("A", "f_A", "B", "f_B")
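
Note: the prolog tokens used throughout this commit are three extra ids placed just past the model vocabulary, one per "understanding" band delimited by args.proba_not_understands and args.proba_understands. A minimal sketch of the convention, reusing the definitions that appear in the hunks below:

    # Three extra token ids past the vocabulary, one per band:
    token_prolog_0 = vocabulary_size + 0  # proba <= args.proba_not_understands
    token_prolog_1 = vocabulary_size + 1  # between the two thresholds
    token_prolog_2 = vocabulary_size + 2  # proba >= args.proba_understands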
@@ -695,13 +695,13 @@ def generate_c_quizz_with_generator(generator, quiz_machine, nb):
         num_classes=args.nb_gpts,
     )
 
-    prolog_c_quizzes = token_prolog_0 * i + token_prolog_2 * (1 - i)
-    prolog_ar_mask = ar_mask.new_zeros(ar_mask.size(0), prolog_c_quizzes.size(1))
+    prologs_c_quizzes = token_prolog_0 * i + token_prolog_2 * (1 - i)
+    prologs_ar_mask = ar_mask.new_zeros(ar_mask.size(0), prologs_c_quizzes.size(1))
 
-    prologued_c_quizzes = torch.cat([prolog_c_quizzes, c_quizzes], dim=1).to(
+    prologued_c_quizzes = torch.cat([prologs_c_quizzes, c_quizzes], dim=1).to(
         main_device
     )
-    prologued_ar_mask = torch.cat([prolog_ar_mask, ar_mask], dim=1).to(main_device)
+    prologued_ar_mask = torch.cat([prologs_ar_mask, ar_mask], dim=1).to(main_device)
 
     seq_logproba = torch.zeros(
         prologued_c_quizzes.size(0), device=prologued_c_quizzes.device
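
Note: prologs_c_quizzes is built by elementwise selection from the tensor i, which the surrounding num_classes=args.nb_gpts call suggests is a 0/1 one-hot over models: each position gets token_prolog_0 where i is 1 and token_prolog_2 where it is 0, and the all-zero AR mask marks the prolog columns as conditioning rather than positions to sample. A minimal self-contained sketch, with hypothetical token ids and sizes:

    import torch
    import torch.nn.functional as F

    token_prolog_0, token_prolog_2 = 100, 102  # hypothetical ids
    nb, nb_gpts = 8, 4                         # hypothetical sizes
    i = F.one_hot(torch.randint(nb_gpts, (nb,)), num_classes=nb_gpts)
    # token_prolog_0 where i == 1, token_prolog_2 elsewhere
    prologs_c_quizzes = token_prolog_0 * i + token_prolog_2 * (1 - i)
    # all-zero mask: prolog positions are given, not generated
    prologs_ar_mask = torch.zeros(nb, prologs_c_quizzes.size(1), dtype=torch.long)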
@@ -729,7 +729,9 @@ def generate_c_quizz_with_generator(generator, quiz_machine, nb):
         prologued_c_quizzes * (prologued_c_quizzes < vocabulary_size).long()
     )
 
-    return prologued_c_quizzes[:, prolog_c_quizzes.size(1) :].to("cpu")
+    c_quizzes = prologued_c_quizzes[:, prologs_c_quizzes.size(1) :]
+
+    return c_quizzes.to("cpu"), prologs_c_quizzes.to("cpu")
 
 
 def batches_for_generator(generator, quiz_machine, models, fraction_w_quizzes=1.0):
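
Note: generate_c_quizzes_with_generator now returns a pair (c_quizzes, prologs), with the prolog prefix sliced off the generated sequence before it is returned. Callers that only need the quizzes discard the second element, as the next hunk does; the test_generator loop further down keeps it to grade the generator. A hypothetical call:

    c_quizzes, prologs = generate_c_quizzes_with_generator(
        generator, quiz_machine, args.batch_size
    )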
@@ -746,7 +748,7 @@ def batches_for_generator(generator, quiz_machine, models, fraction_w_quizzes=1.
                 )
             else:
                 # Or we use the generator itself to generate them
-                c_quizzes = generate_c_quizz_with_generator(
+                c_quizzes, _ = generate_c_quizzes_with_generator(
                     generator, quiz_machine, args.batch_size
                 )
 
@@ -770,13 +772,13 @@ def batches_for_generator(generator, quiz_machine, models, fraction_w_quizzes=1.
                 u2 = probas >= args.proba_understands
                 u1 = (u0 | u2) == False
 
-                prolog = (
+                prologs = (
                     (u0.long() * token_prolog_0)
                     + (u1.long() * token_prolog_1)
                     + (u2.long() * token_prolog_2)
                 )
 
-                prologued_c_quizzes = torch.cat([prolog, c_quizzes], dim=1)
+                prologued_c_quizzes = torch.cat([prologs, c_quizzes], dim=1)
 
                 # nb_u2 = u2.long().sum(dim=1)
                 # nb_u0 = u0.long().sum(dim=1)
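
Note: the two thresholds split probas into three bands, and the three boolean masks are folded into one prolog token per model. A minimal self-contained sketch with hypothetical thresholds and token ids; the in-between mask is written ~(u0 | u2), which is equivalent to the (u0 | u2) == False form above:

    import torch

    token_prolog_0, token_prolog_1, token_prolog_2 = 100, 101, 102  # hypothetical
    proba_not_understands, proba_understands = 0.1, 0.9             # hypothetical

    probas = torch.tensor([[0.05, 0.50, 0.95]])  # one quiz, three models
    u0 = probas <= proba_not_understands         # does not understand
    u2 = probas >= proba_understands             # understands
    u1 = ~(u0 | u2)                              # in-between band
    prologs = (
        (u0.long() * token_prolog_0)
        + (u1.long() * token_prolog_1)
        + (u2.long() * token_prolog_2)
    )
    # prologs -> tensor([[100, 101, 102]])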
@@ -1031,17 +1033,6 @@ if args.dirty_debug:
 ######################################################################
 
 if args.test_generator:
-    filename = f"generator.pth"
-
-    try:
-        d = torch.load(os.path.join(args.result_dir, filename))
-        generator.load_state_dict(d[0])
-        generator.main_test_accuracy = d[1]
-        log_string(f"successfully loaded {filename}")
-    except FileNotFoundError:
-        log_string(f"cannot find {filename}")
-        pass
-
     token_prolog_0 = vocabulary_size + 0
     token_prolog_1 = vocabulary_size + 1
     token_prolog_2 = vocabulary_size + 2
@@ -1060,6 +1051,17 @@ if args.test_generator:
 
     generator.main_test_accuracy = 0.0
 
+    filename = f"generator.pth"
+
+    try:
+        d = torch.load(os.path.join(args.result_dir, filename))
+        generator.load_state_dict(d[0])
+        generator.main_test_accuracy = d[1]
+        log_string(f"successfully loaded {filename}")
+    except FileNotFoundError:
+        log_string(f"cannot find {filename}")
+        pass
+
     for n_epoch in range(args.nb_epochs):
         one_generator_epoch(
             generator,
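
Note: these two hunks move the generator.pth loading block, unchanged, from before the generator setup to after the generator.main_test_accuracy = 0.0 default, so an accuracy restored from the checkpoint is no longer clobbered by the later initialization. Condensed from the hunks above:

    # Order after this commit:
    generator.main_test_accuracy = 0.0      # default first
    try:
        d = torch.load(os.path.join(args.result_dir, "generator.pth"))
        generator.load_state_dict(d[0])
        generator.main_test_accuracy = d[1]  # override only on a successful load
    except FileNotFoundError:
        pass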
@@ -1076,7 +1078,7 @@ if args.test_generator:
         )
         log_string(f"wrote {filename}")
 
-        c_quizzes = generate_c_quizz_with_generator(
+        c_quizzes, prologs = generate_c_quizzes_with_generator(
             generator, quiz_machine, args.batch_size
         )
 
@@ -1086,12 +1088,34 @@ if args.test_generator:
             models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1)
         )
 
-        print(seq_logproba.exp())
+        probas = seq_logproba.exp()
+
+        u0 = probas <= args.proba_not_understands
+        u2 = probas >= args.proba_understands
+        u1 = (u0 | u2) == False
+
+        predicted_prologs = (
+            (u0.long() * token_prolog_0)
+            + (u1.long() * token_prolog_1)
+            + (u2.long() * token_prolog_2)
+        )
 
         comments = []
 
-        for l in seq_logproba:
-            comments.append("proba " + " ".join([f"{x.exp().item():.02f}" for x in l]))
+        nb_errors = (predicted_prologs != prologs).long().sum()
+        nb_total = prologs.numel()
+
+        log_string(f"generator_error {nb_errors} / {nb_total}")
+
+        def readable(prologs):
+            return (prologs == token_prolog_1) + 2 * (prologs == token_prolog_2)
+
+        for aa, ee, ff in zip(probas, readable(predicted_prologs), readable(prologs)):
+            sa = "prolog " + " ".join(
+                [f"{e.item()}/{f.item()}" for e, f in zip(ee, ff)]
+            )
+            sp = "proba " + " ".join([f"{p.item():.02f}" for p in aa])
+            comments.append(sa + "\n" + sp)
 
         filename = f"generator_batch_{n_epoch:04d}.png"
         quiz_machine.problem.save_quizzes_as_image(
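
Note: the new reporting re-derives a prolog token from each model's measured probability and compares it against the prolog the generator was conditioned on, and readable() maps token ids back to the band index 0/1/2 for the per-quiz comments. A minimal self-contained sketch with hypothetical token ids:

    import torch

    token_prolog_0, token_prolog_1, token_prolog_2 = 100, 101, 102  # hypothetical

    def readable(prologs):
        # token id -> band index 0/1/2 for display
        return (prologs == token_prolog_1) + 2 * (prologs == token_prolog_2)

    prologs = torch.tensor([[100, 102], [102, 102]])            # conditioning
    predicted_prologs = torch.tensor([[100, 101], [102, 102]])  # from probas
    nb_errors = (predicted_prologs != prologs).long().sum().item()
    print(f"generator_error {nb_errors} / {prologs.numel()}")   # -> 1 / 4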