Update.
author     François Fleuret <francois@fleuret.org>
           Mon, 24 Jun 2024 19:23:26 +0000 (21:23 +0200)
committer  François Fleuret <francois@fleuret.org>
           Mon, 24 Jun 2024 19:23:26 +0000 (21:23 +0200)
main.py
mygpt.py
tasks.py

diff --git a/main.py b/main.py
index ee4e9e5..8033836 100755
--- a/main.py
+++ b/main.py
@@ -73,7 +73,7 @@ parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
 
 parser.add_argument("--nb_gpts", type=int, default=5)
 
-parser.add_argument("--check", action="store_true", default=False)
+parser.add_argument("--dirty_debug", action="store_true", default=False)
 
 ######################################################################
 
@@ -182,9 +182,9 @@ for n in vars(args):
 
 ######################################################################
 
-if args.check:
-    args.nb_train_samples = 25000
-    args.nb_test_samples = 1000
+if args.dirty_debug:
+    args.nb_train_samples = 2500
+    args.nb_test_samples = 100
 
 if args.physical_batch_size is None:
     args.physical_batch_size = args.batch_size
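
The renamed flag follows the usual pattern of gating cheap settings behind a boolean switch. A minimal, self-contained sketch of the same pattern (the real defaults for the two sample counts are defined elsewhere in main.py; the values below are placeholders):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--dirty_debug", action="store_true", default=False)
parser.add_argument("--nb_train_samples", type=int, default=250000)  # placeholder default
parser.add_argument("--nb_test_samples", type=int, default=10000)    # placeholder default
args = parser.parse_args()

# A debug run trades statistical power for turnaround time by shrinking
# both datasets, exactly as in the hunk above.
if args.dirty_debug:
    args.nb_train_samples = 2500
    args.nb_test_samples = 100
```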
@@ -339,12 +339,12 @@ def create_quizzes(
 ):
     kept = []
 
-    sum_logits = 0
+    sum_logits, sum_nb_quizzes = 0, 0
 
     while sum([x.size(0) for x in kept]) < nb_for_train + nb_for_test:
         nb_to_generate = 4 * (nb_for_train + nb_for_test)
 
-        new_quizzes, nb_correct, _sum_logits = task.create_new_quizzes(
+        new_quizzes, nb_correct, average_logits = task.create_new_quizzes(
             n_epoch=n_epoch,
             result_dir=args.result_dir,
             logger=log_string,
@@ -354,12 +354,18 @@ def create_quizzes(
             desired_average_logits=desired_average_logits,
         )
 
-        sum_logits += _sum_logits
+        sum_logits += new_quizzes.size(0) * average_logits
+        sum_nb_quizzes += new_quizzes.size(0)
 
         to_keep = new_quizzes[nb_correct == len(other_models) - 1]
+
+        if args.dirty_debug:
+            to_keep = new_quizzes
+
         log_string(
             f"keep {to_keep.size(0)}/{new_quizzes.size(0)} quizzes ({to_keep.size(0)*100/new_quizzes.size(0):.02f}%)"
         )
+
         kept.append(to_keep)
 
     new_quizzes = torch.cat(kept, dim=0)[: nb_for_train + nb_for_test]
@@ -374,7 +380,7 @@ def create_quizzes(
         log_string,
     )
 
-    return sum_logits / new_quizzes.size(0)
+    return sum_logits / sum_nb_quizzes
 
 
 ######################################################################
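The accounting change above matters because create_new_quizzes now reports a per-batch average rather than a raw sum: to recover a correct overall average across rounds of different sizes, each round's average must be weighted by its quiz count. A toy illustration of the invariant (the numbers are made up):

```python
# Count-weighted running mean: round i contributes n_i quizzes with
# per-round average a_i; the overall mean is sum(n_i * a_i) / sum(n_i).
sum_logits, sum_nb_quizzes = 0.0, 0

for n_i, a_i in [(400, -3.2), (400, -3.0), (200, -2.9)]:  # (count, average) per round
    sum_logits += n_i * a_i
    sum_nb_quizzes += n_i

overall_average = sum_logits / sum_nb_quizzes  # -3060 / 1000 == -3.06
```
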
@@ -408,7 +414,7 @@ accuracy_to_make_quizzes = 0.975
 nb_new_quizzes_for_train = 1000
 nb_new_quizzes_for_test = 100
 
-if args.check:
+if args.dirty_debug:
     accuracy_to_make_quizzes = 0.0
     nb_new_quizzes_for_train = 100
     nb_new_quizzes_for_test = 10
diff --git a/mygpt.py b/mygpt.py
index 3e63567..ab4ccbc 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -279,13 +279,12 @@ class MyGPT(nn.Module):
         self,
         input,
         ar_mask,
+        summed_logits,
         temperature=1.0,
         deterministic_synthesis=False,
         forbidden_tokens=None,
         forced_biases=None,
     ):
-        sum_logits = 0
-
         to_generate = (ar_mask.sum(0) > 0).nonzero()
 
         if to_generate.min() > 0:
@@ -297,7 +296,7 @@ class MyGPT(nn.Module):
 
             logits = output[:, s]
 
-            logits = logits.log_softmax(dim=1) / temperature
+            logits = (logits / temperature).log_softmax(dim=-1)
 
             if forbidden_tokens is not None:
                 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
@@ -306,18 +305,17 @@ class MyGPT(nn.Module):
                 logits = logits + forced_biases[None, :]
 
             if deterministic_synthesis:
-                t_next = logits.argmax(1)
+                t_next = logits.argmax(-1)
             else:
                 dist = torch.distributions.categorical.Categorical(logits=logits)
                 t_next = dist.sample()
-                sum_logits += logits.log_softmax(dim=1)[
-                    torch.arange(t_next.size(0)), t_next
-                ].sum()
+                if summed_logits is not None:
+                    summed_logits += logits[torch.arange(t_next.size(0)), t_next].sum(
+                        dim=-1
+                    )
 
             input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
 
-        return sum_logits
-
     def record_attention(self, v=True):
         for m in self.modules():
             if isinstance(m, QKVAttention):
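
The one-line temperature fix in mygpt.py is worth spelling out: dividing by the temperature *after* log_softmax leaves an unnormalized score, while scaling the raw logits *before* log_softmax yields a proper tempered distribution, which Categorical then samples from. A quick check, assuming only torch:

```python
import torch

logits = torch.randn(5)
T = 2.0

wrong = logits.log_softmax(dim=-1) / T   # old code: not a log-distribution
right = (logits / T).log_softmax(dim=-1) # new code: tempered and normalized

print(wrong.exp().sum())  # != 1 in general
print(right.exp().sum())  # == 1 (up to floating point)
```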
diff --git a/tasks.py b/tasks.py
index 2a1833d..39372f3 100755
--- a/tasks.py
+++ b/tasks.py
@@ -22,6 +22,7 @@ def masked_inplace_autoregression(
     batch_size,
     input,
     ar_mask,
+    summed_logits,
     temperature,
     deterministic_synthesis,
     forbidden_tokens=None,
@@ -41,16 +42,15 @@ def masked_inplace_autoregression(
             total=(input.size(0) + batch_size - 1) // batch_size,
         )
 
-    sum_logits = 0
-
     with torch.autograd.no_grad():
         t = model.training
         model.eval()
 
         for input, ar_mask in batches:
-            sum_logits += model.masked_inplace_autoregression(
+            model.masked_inplace_autoregression(
                 input=input,
                 ar_mask=ar_mask,
+                summed_logits=summed_logits,
                 temperature=temperature,
                 deterministic_synthesis=deterministic_synthesis,
                 forbidden_tokens=forbidden_tokens,
@@ -59,8 +59,6 @@ def masked_inplace_autoregression(
 
         model.train(t)
 
-    return sum_logits
-
 
 ######################################################################
 
@@ -180,6 +178,7 @@ class World(Task):
                 batch_size=self.batch_size,
                 input=result,
                 ar_mask=ar_mask,
+                summed_logits=None,
                 temperature=1.0,
                 deterministic_synthesis=deterministic_synthesis,
                 progress_bar_desc=None,
@@ -219,6 +218,7 @@ class World(Task):
             batch_size=self.batch_size,
             input=result,
             ar_mask=ar_mask,
+            summed_logits=None,
             temperature=1.0,
             deterministic_synthesis=deterministic_synthesis,
             progress_bar_desc=None,
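
The recurring summed_logits=None argument reflects the new calling convention: instead of returning a scalar, masked_inplace_autoregression accumulates log-probabilities in place into a tensor the caller owns, and None disables the bookkeeping at decode-only call sites. A minimal sketch of the pattern (names are hypothetical):

```python
import torch

def generate(batch, summed_logits=None):
    """Stand-in for the real routine: compute per-sample scores and,
    if a buffer is provided, accumulate them in place."""
    scores = batch.sum(dim=-1)  # placeholder for the sampled log-probs
    if summed_logits is not None:
        summed_logits += scores.sum()  # the caller's buffer carries the result

buf = torch.zeros(())
for batch in torch.randn(3, 4, 2):  # three "batches"
    generate(batch, summed_logits=buf)

generate(torch.randn(4, 2))         # None: bookkeeping skipped
```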
@@ -266,23 +266,27 @@ class World(Task):
         )
 
         ar_mask = torch.full(quizzes.size(), 1, device=self.device)
+        summed_logits = torch.empty(nb, device=self.device)
 
         temperature = 1
         d_temperature = 1
 
         while True:
-            sum_logits = masked_inplace_autoregression(
+            summed_logits[...] = 0
+
+            masked_inplace_autoregression(
                 model=model,
                 batch_size=self.batch_size,
                 input=quizzes,
                 ar_mask=ar_mask,
+                summed_logits=summed_logits,
                 temperature=temperature,
                 deterministic_synthesis=False,
                 progress_bar_desc="creating quizzes",
                 device=self.device,
             )
 
-            average_logits = sum_logits / quizzes.size(0)
+            average_logits = summed_logits.mean()
 
             logger(f"{average_logits=} {desired_average_logits=}")
 
@@ -290,14 +294,16 @@ class World(Task):
                 break
 
             # Oh man that's ugly
-            if average_logits > desired_average_logits:
+            if average_logits < desired_average_logits:
                 if d_temperature < 0:
                     d_temperature *= -0.5
                 temperature += d_temperature
-            else:
+            elif average_logits > desired_average_logits * 0.95:
                 if d_temperature > 0:
                     d_temperature *= -0.5
                 temperature += d_temperature
+            else:
+                break
 
             logger(f"chaging temperature to {temperature}")
 
@@ -329,6 +335,7 @@ class World(Task):
                 batch_size=self.batch_size,
                 input=result,
                 ar_mask=ar_mask,
+                summed_logits=None,
                 temperature=1.0,
                 deterministic_synthesis=True,
                 progress_bar_desc="solving quizzes",
@@ -344,6 +351,7 @@ class World(Task):
                 batch_size=self.batch_size,
                 input=reverse_result,
                 ar_mask=ar_mask,
+                summed_logits=None,
                 temperature=1.0,
                 deterministic_synthesis=True,
                 progress_bar_desc="solving reversed quizzes",
@@ -363,4 +371,4 @@ class World(Task):
         # for k in nb_correct:
         # f.write(f"{k}\n")
 
-        return quizzes, nb_correct.sum(dim=0), sum_logits
+        return quizzes, nb_correct.sum(dim=0), summed_logits.mean()
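
The rewritten temperature loop in create_new_quizzes is a damped, sign-flipping line search: step the temperature in one direction, and each time the measured average crosses the target band, halve the step and reverse it, stopping once the average lands inside the band (here between desired_average_logits and desired_average_logits * 0.95). The same mechanism in isolation, on a toy monotone response (function and bounds are hypothetical):

```python
def tune(measure, lo, hi, t=1.0, d=1.0, max_steps=100):
    """Damped sign-flipping search: walk t with step d, and whenever
    measure(t) overshoots the [lo, hi] band, halve d and reverse it."""
    for _ in range(max_steps):
        m = measure(t)
        if m < lo:      # below the band: make sure we step upward
            if d < 0:
                d *= -0.5
            t += d
        elif m > hi:    # above the band: make sure we step downward
            if d > 0:
                d *= -0.5
            t += d
        else:           # inside the band: converged
            return t
    return t

print(tune(lambda t: t, 2.4, 2.6))  # visits 1 -> 2 -> 3 -> 2.5, returns 2.5
```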