Update.
author François Fleuret <francois@fleuret.org>
Wed, 31 Jul 2024 07:06:42 +0000 (09:06 +0200)
committer François Fleuret <francois@fleuret.org>
Wed, 31 Jul 2024 07:06:42 +0000 (09:06 +0200)
main.py

diff --git a/main.py b/main.py
index d50837a..3cc536c 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -678,30 +678,30 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 def generate_c_quizz_with_generator(generator, quiz_machine, nb):
     generator.to(main_device)
 
-    c_quizzes = quiz_machine.problem.create_empty_quizzes(
-        nb, struct=("A", "f_A", "B", "f_B")
-    )
+    struct = ("A", "f_A", "B", "f_B")
+
+    c_quizzes = quiz_machine.problem.create_empty_quizzes(nb, struct=struct)
+    ar_mask = quiz_machine.make_ar_mask(c_quizzes, struct, (1, 1, 1, 1))
 
     i = F.one_hot(
         torch.randint(args.nb_gpts, (c_quizzes.size(0),)),
         num_classes=args.nb_gpts,
     )
 
-    prolog = token_prolog_0 * i + token_prolog_2 * (1 - i)
-    len_prolog, len_quiz = prolog.size(1), c_quizzes.size(1)
-
-    prologued_c_quizzes = torch.cat([prolog, c_quizzes], dim=1).to(main_device)
-
-    T = torch.arange(prologued_c_quizzes.size(1), device=prologued_c_quizzes.device)[
-        None, :
-    ]
+    prolog_c_quizzes = token_prolog_0 * i + token_prolog_2 * (1 - i)
+    prolog_ar_mask = ar_mask.new_zeros(ar_mask.size(0), prolog_c_quizzes.size(1))
 
-    ar_mask = ((T >= len_prolog) & ((T - len_prolog) % (len_quiz // 4) > 0)).long()
+    prologued_c_quizzes = torch.cat([prolog_c_quizzes, c_quizzes], dim=1).to(
+        main_device
+    )
+    prologued_ar_mask = torch.cat([prolog_ar_mask, ar_mask], dim=1).to(main_device)
 
     seq_logproba = torch.zeros(
         prologued_c_quizzes.size(0), device=prologued_c_quizzes.device
     )
 
+    generator.temperature = args.temperature_hot
+
     with torch.autograd.no_grad():
         t = generator.training
         generator.eval()
@@ -709,28 +709,30 @@ def generate_c_quizz_with_generator(generator, quiz_machine, nb):
         one_batch_masked_inplace_autoregression(
             generator,
             prologued_c_quizzes,
-            ar_mask,
+            prologued_ar_mask,
             seq_logproba,
             deterministic_synthesis=False,
         )
 
         generator.train(t)
 
+    generator.reset_transformations()
+
     prologued_c_quizzes = (
         prologued_c_quizzes * (prologued_c_quizzes < vocabulary_size).long()
     )
 
-    return prologued_c_quizzes[:, len_prolog:].to("cpu")
+    return prologued_c_quizzes[:, prolog_c_quizzes.size(1) :].to("cpu")
 
 
-def batches_for_generator(generator, quiz_machine, models, w_quizzes=True):
+def batches_for_generator(generator, quiz_machine, models, fraction_w_quizzes=1.0):
     samples = []
 
     for _ in range(args.nb_train_samples // args.batch_size):
         while sum([x.size(0) for x in samples]) < args.batch_size:
             # Generate a bunch of quizzes
 
-            if w_quizzes:
+            if torch.rand(1).item() <= fraction_w_quizzes:
                 # Either we start with the world quizzes
                 c_quizzes = quiz_machine.problem.generate_w_quizzes(
                     args.batch_size, progress_bar=False
@@ -738,7 +740,7 @@ def batches_for_generator(generator, quiz_machine, models, w_quizzes=True):
             else:
                 # Or we use the generator itself to generate them
                 c_quizzes = generate_c_quizz_with_generator(
-                    args.batch_size, generator, quiz_machine
+                    generator, quiz_machine, args.batch_size
                 )
 
             # We remove the trivial ones
@@ -757,27 +759,24 @@ def batches_for_generator(generator, quiz_machine, models, w_quizzes=True):
 
                 probas = seq_logproba.exp()
 
-                nu = probas <= args.proba_not_understands
-                u = probas >= args.proba_understands
+                u0 = probas <= args.proba_not_understands
+                u2 = probas >= args.proba_understands
+                u1 = (u0 | u2) == False
 
                 prolog = (
-                    (nu.long() * token_prolog_0)
-                    + (((nu == False) & (u == False)).long() * token_prolog_1)
-                    + (u.long() * token_prolog_2)
+                    (u0.long() * token_prolog_0)
+                    + (u1.long() * token_prolog_1)
+                    + (u2.long() * token_prolog_2)
                 )
 
                 prologued_c_quizzes = torch.cat([prolog, c_quizzes], dim=1)
 
-                # nb_u = u.long().sum(dim=1)
-                # nb_nu = nu.long().sum(dim=1)
-
-                # prologued_c_quizzes = prologued_c_quizzes[
-                # (nb_u + nb_nu == args.nb_gpts)
-                # & (nb_nu >= 1)
-                # & (nb_nu <= args.max_fail_to_validate)
-                # ]
+                # nb_u2 = u2.long().sum(dim=1)
+                # nb_u0 = u0.long().sum(dim=1)
+                # prologued_c_quizzes = prologued_c_quizzes[(nb_u2 >= 1) & (nb_u0 >= 1)]
 
-                samples.append(prologued_c_quizzes)
+                if prologued_c_quizzes.size(0) > 0:
+                    samples.append(prologued_c_quizzes)
 
         # Now we yield a batch
 
@@ -788,7 +787,7 @@ def batches_for_generator(generator, quiz_machine, models, w_quizzes=True):
 
 
 def one_generator_epoch(
-    generator, quiz_machine, models, w_quizzes=True, local_device=main_device
+    generator, quiz_machine, models, fraction_w_quizzes, local_device=main_device
 ):
     model.to(local_device).train()
 
@@ -796,10 +795,11 @@ def one_generator_epoch(
 
     nb_train_samples, acc_train_loss = 0, 0.0
 
-    hard_w_quizzes = []
-
     src = batches_for_generator(
-        generator=generator, quiz_machine=quiz_machine, models=models
+        generator=generator,
+        quiz_machine=quiz_machine,
+        models=models,
+        fraction_w_quizzes=fraction_w_quizzes,
     )
 
     for input in tqdm.tqdm(
@@ -1047,7 +1047,7 @@ if args.test_generator:
             generator,
             quiz_machine=quiz_machine,
             models=models,
-            w_quizzes=True,
+            fraction_w_quizzes=1 if n_epoch < 25 else 0.5,
             local_device=main_device,
         )
 
@@ -1081,14 +1081,6 @@ if args.test_generator:
         )
         log_string(f"wrote {filename}")
 
-    one_generator_epoch(
-        generator,
-        quiz_machine=quiz_machine,
-        models=models,
-        w_quizzes=False,
-        local_device=main_device,
-    )
-
     exit(0)
 
 ######################################################################