Update: rename the logit-sum accumulators to sequence log-probability names (summed_logits → seq_logproba, desired_average_logits → min_ave_seq_logproba, average_logits → ave_seq_logproba) and move the accumulation out of the sampling branch in mygpt.py.
author     François Fleuret <francois@fleuret.org>
           Tue, 25 Jun 2024 09:51:04 +0000 (11:51 +0200)
committer  François Fleuret <francois@fleuret.org>
           Tue, 25 Jun 2024 09:51:04 +0000 (11:51 +0200)
main.py
mygpt.py
tasks.py

diff --git a/main.py b/main.py
index ebecad8..2c759ec 100755
--- a/main.py
+++ b/main.py
@@ -350,7 +350,7 @@ def create_c_quizzes(
     task,
     nb_for_train=1000,
     nb_for_test=100,
-    desired_average_logits=None,
+    min_ave_seq_logproba=None,
 ):
     kept = []
 
@@ -359,17 +359,17 @@ def create_c_quizzes(
     while sum([x.size(0) for x in kept]) < nb_for_train + nb_for_test:
         nb_to_generate = 4 * (nb_for_train + nb_for_test)
 
-        new_c_quizzes, nb_correct, average_logits = task.create_c_quizzes(
+        new_c_quizzes, nb_correct, ave_seq_logproba = task.create_c_quizzes(
             n_epoch=n_epoch,
             result_dir=args.result_dir,
             logger=log_string,
             nb=nb_to_generate,
             model=model,
             other_models=other_models,
-            desired_average_logits=desired_average_logits,
+            min_ave_seq_logproba=min_ave_seq_logproba,
         )
 
-        sum_logits += new_c_quizzes.size(0) * average_logits
+        sum_logits += new_c_quizzes.size(0) * ave_seq_logproba
         sum_nb_c_quizzes += new_c_quizzes.size(0)
 
         to_keep = new_c_quizzes[nb_correct == len(other_models) - 1]
@@ -425,7 +425,7 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 
 ######################################################################
 
-desired_average_logits = None
+min_ave_seq_logproba = None
 
 for n_epoch in range(args.nb_epochs):
     log_string(f"--- epoch {n_epoch} ----------------------------------------")
@@ -462,21 +462,21 @@ for n_epoch in range(args.nb_epochs):
         other_models = models.copy()
         other_models.remove(model)
 
-        average_logits = create_c_quizzes(
+        ave_seq_logproba = create_c_quizzes(
             model,
             other_models,
             task,
             nb_for_train=nb_new_c_quizzes_for_train,
             nb_for_test=nb_new_c_quizzes_for_test,
-            desired_average_logits=desired_average_logits,
+            min_ave_seq_logproba=min_ave_seq_logproba,
         )
 
         # We keep the first average logits as a reference
-        if desired_average_logits is None:
-            desired_average_logits = average_logits
+        if min_ave_seq_logproba is None:
+            min_ave_seq_logproba = ave_seq_logproba
         else:
             log_string(
-                f"desired_average_logits {desired_average_logits} average_logits {average_logits}"
+                f"min_ave_seq_logproba {min_ave_seq_logproba} ave_seq_logproba {ave_seq_logproba}"
             )
 
         # We update everyone
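create_c_quizzes above keeps a pair of accumulators, sum_logits and sum_nb_c_quizzes, weighting each batch's mean by its size. A minimal sketch of that pattern follows, with illustrative names and example values; dividing the two accumulators at the end is the assumed use, which the commit itself does not show:

    # Size-weighted running mean: each batch contributes its mean weighted
    # by its number of sequences, so the final ratio is the overall
    # per-sequence mean.
    sum_logproba, sum_n = 0.0, 0
    for batch_mean, batch_size in [(-12.3, 400), (-11.8, 250)]:  # example values
        sum_logproba += batch_size * batch_mean
        sum_n += batch_size
    overall_mean = sum_logproba / sum_n  # assumed final use of the accumulators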
diff --git a/mygpt.py b/mygpt.py
index ab4ccbc..809f790 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -279,7 +279,7 @@ class MyGPT(nn.Module):
         self,
         input,
         ar_mask,
-        summed_logits,
+        seq_logproba,
         temperature=1.0,
         deterministic_synthesis=False,
         forbidden_tokens=None,
@@ -309,10 +309,10 @@ class MyGPT(nn.Module):
             else:
                 dist = torch.distributions.categorical.Categorical(logits=logits)
                 t_next = dist.sample()
-                if summed_logits is not None:
-                    summed_logits += logits[torch.arange(t_next.size(0)), t_next].sum(
-                        dim=-1
-                    )
+
+            if seq_logproba is not None:
+                all_t = torch.arange(t_next.size(0))
+                seq_logproba += logits[all_t, t_next].sum(dim=-1)
 
             input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
 
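For reference, a self-contained sketch of the accumulation pattern this hunk introduces. It assumes per-step logits of shape (batch, vocab); the log_softmax call is an addition for the sketch, not part of the commit, and makes the accumulated values proper per-sequence log-probabilities rather than raw logit sums:

    import torch

    def sample_and_accumulate(logits, seq_logproba):
        # logits: (batch, vocab) at the current position
        logprobs = logits.log_softmax(dim=-1)  # normalize to log-probabilities
        dist = torch.distributions.categorical.Categorical(logits=logits)
        t_next = dist.sample()  # one sampled token per batch element
        all_t = torch.arange(t_next.size(0))
        seq_logproba += logprobs[all_t, t_next]  # add log P(t_next | prefix)
        return t_next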
diff --git a/tasks.py b/tasks.py
index 43f7d53..a522728 100755
--- a/tasks.py
+++ b/tasks.py
@@ -22,7 +22,7 @@ def masked_inplace_autoregression(
     batch_size,
     input,
     ar_mask,
-    summed_logits,
+    seq_logproba,
     temperature,
     deterministic_synthesis,
     forbidden_tokens=None,
@@ -50,7 +50,7 @@ def masked_inplace_autoregression(
             model.masked_inplace_autoregression(
                 input=input,
                 ar_mask=ar_mask,
-                summed_logits=summed_logits,
+                seq_logproba=seq_logproba,
                 temperature=temperature,
                 deterministic_synthesis=deterministic_synthesis,
                 forbidden_tokens=forbidden_tokens,
@@ -184,7 +184,7 @@ class World(Task):
                 batch_size=self.batch_size,
                 input=result,
                 ar_mask=ar_mask,
-                summed_logits=None,
+                seq_logproba=None,
                 temperature=1.0,
                 deterministic_synthesis=deterministic_synthesis,
                 progress_bar_desc=None,
@@ -224,7 +224,7 @@ class World(Task):
             batch_size=self.batch_size,
             input=result,
             ar_mask=ar_mask,
-            summed_logits=None,
+            seq_logproba=None,
             temperature=1.0,
             deterministic_synthesis=deterministic_synthesis,
             progress_bar_desc=None,
@@ -262,7 +262,7 @@ class World(Task):
         nb,
         model,
         other_models,
-        desired_average_logits=None,
+        min_ave_seq_logproba=None,
     ):
         ###############################################################
         # Generate quizzes with model
@@ -272,39 +272,39 @@ class World(Task):
         )
 
         ar_mask = torch.full(c_quizzes.size(), 1, device=self.device)
-        summed_logits = torch.empty(nb, device=self.device)
+        seq_logproba = torch.empty(nb, device=self.device)
 
         temperature = 1
         d_temperature = 1
 
         while True:
-            summed_logits[...] = 0
+            seq_logproba[...] = 0
 
             masked_inplace_autoregression(
                 model=model,
                 batch_size=self.batch_size,
                 input=c_quizzes,
                 ar_mask=ar_mask,
-                summed_logits=summed_logits,
+                seq_logproba=seq_logproba,
                 temperature=temperature,
                 deterministic_synthesis=False,
                 progress_bar_desc="sampling c_quizzes",
                 device=self.device,
             )
 
-            average_logits = summed_logits.mean()
+            ave_seq_logproba = seq_logproba.mean()
 
-            logger(f"{average_logits=} {desired_average_logits=}")
+            logger(f"{ave_seq_logproba=} {min_ave_seq_logproba=}")
 
-            if desired_average_logits is None:
+            if min_ave_seq_logproba is None:
                 break
 
             # Oh man that's ugly
-            if average_logits < desired_average_logits * 1.1:
+            if ave_seq_logproba < min_ave_seq_logproba * 1.1:
                 if d_temperature > 0:
                     d_temperature *= -0.5
                 temperature += d_temperature
-            elif average_logits > desired_average_logits:
+            elif ave_seq_logproba > min_ave_seq_logproba:
                 if d_temperature < 0:
                     d_temperature *= -0.5
                 temperature += d_temperature
@@ -341,7 +341,7 @@ class World(Task):
                 batch_size=self.batch_size,
                 input=result,
                 ar_mask=ar_mask,
-                summed_logits=None,
+                seq_logproba=None,
                 temperature=1.0,
                 deterministic_synthesis=True,
                 progress_bar_desc="solving c_quizzes",
@@ -357,7 +357,7 @@ class World(Task):
                 batch_size=self.batch_size,
                 input=reverse_result,
                 ar_mask=ar_mask,
-                summed_logits=None,
+                seq_logproba=None,
                 temperature=1.0,
                 deterministic_synthesis=True,
                 progress_bar_desc="solving reversed c_quizzes",
@@ -372,9 +372,4 @@ class World(Task):
 
         nb_correct = torch.cat(nb_correct, dim=0).sum(dim=0)
 
-        # filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat")
-        # with open(filename, "w") as f:
-        # for k in nb_correct:
-        # f.write(f"{k}\n")
-
-        return c_quizzes, nb_correct, summed_logits.mean()
+        return c_quizzes, nb_correct, seq_logproba.mean()
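The temperature search in the sampling loop above (the part flagged "Oh man that's ugly") is a damped oscillation: the step halves and reverses direction whenever the measured average crosses the target band. A standalone sketch, with measure standing in for a full sampling pass and the in-band exit assumed from code outside the hunk:

    def find_temperature(measure, floor, temperature=1.0, d_temperature=1.0):
        # Log-probabilities are negative, so floor * 1.1 lies below floor:
        # the acceptance band is [floor * 1.1, floor].
        while True:
            value = measure(temperature)
            if value < floor * 1.1:    # samples too unlikely: cool down
                if d_temperature > 0:  # reverse and damp if heading hotter
                    d_temperature *= -0.5
                temperature += d_temperature
            elif value > floor:        # samples too likely: heat up
                if d_temperature < 0:  # reverse and damp if heading cooler
                    d_temperature *= -0.5
                temperature += d_temperature
            else:                      # inside the band (assumed exit)
                return temperature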