task,
     nb_for_train=1000,
     nb_for_test=100,
-    desired_average_logits=None,
+    min_ave_seq_logproba=None,
 ):
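+    # min_ave_seq_logproba, when set, is the reference level for the
+    # average sequence log-probability of the generated c_quizzes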
     kept = []
 
     while sum([x.size(0) for x in kept]) < nb_for_train + nb_for_test:
         nb_to_generate = 4 * (nb_for_train + nb_for_test)
 
-        new_c_quizzes, nb_correct, average_logits = task.create_c_quizzes(
+        new_c_quizzes, nb_correct, ave_seq_logproba = task.create_c_quizzes(
             n_epoch=n_epoch,
             result_dir=args.result_dir,
             logger=log_string,
             nb=nb_to_generate,
             model=model,
             other_models=other_models,
-            desired_average_logits=desired_average_logits,
+            min_ave_seq_logproba=min_ave_seq_logproba,
         )
 
-        sum_logits += new_c_quizzes.size(0) * average_logits
+        sum_logits += new_c_quizzes.size(0) * ave_seq_logproba
         sum_nb_c_quizzes += new_c_quizzes.size(0)
 
         to_keep = new_c_quizzes[nb_correct == len(other_models) - 1]
 
 ######################################################################
 
-desired_average_logits = None
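+# Reference average sequence log-probability, taken from the first
+# epoch's c_quizzes and used afterward as a minimum target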
+min_ave_seq_logproba = None
 
 for n_epoch in range(args.nb_epochs):
     log_string(f"--- epoch {n_epoch} ----------------------------------------")
         other_models = models.copy()
         other_models.remove(model)
 
-        average_logits = create_c_quizzes(
+        ave_seq_logproba = create_c_quizzes(
             model,
             other_models,
             task,
             nb_for_train=nb_new_c_quizzes_for_train,
             nb_for_test=nb_new_c_quizzes_for_test,
-            desired_average_logits=desired_average_logits,
+            min_ave_seq_logproba=min_ave_seq_logproba,
         )
 
-        # We keep the first average logits as a reference
+        # We keep the first average sequence log-probability as a reference
-        if desired_average_logits is None:
-            desired_average_logits = average_logits
+        if min_ave_seq_logproba is None:
+            min_ave_seq_logproba = ave_seq_logproba
         else:
             log_string(
-                f"desired_average_logits {desired_average_logits} average_logits {average_logits}"
+                f"min_ave_seq_logproba {min_ave_seq_logproba} ave_seq_logproba {ave_seq_logproba}"
             )
 
         # We update everyone
 
         self,
         input,
         ar_mask,
-        summed_logits,
+        seq_logproba,
         temperature=1.0,
         deterministic_synthesis=False,
         forbidden_tokens=None,
             else:
                 dist = torch.distributions.categorical.Categorical(logits=logits)
                 t_next = dist.sample()
-                if summed_logits is not None:
-                    summed_logits += logits[torch.arange(t_next.size(0)), t_next].sum(
-                        dim=-1
-                    )
+
+            if seq_logproba is not None:
+                # accumulate, for each sequence, the logit of the token
+                # just chosen (a log-probability up to normalization)
+                all_t = torch.arange(t_next.size(0))
+                seq_logproba += logits[all_t, t_next]
 
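+            # write the sampled token only at positions where ar_mask is 1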
             input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
 
 
     batch_size,
     input,
     ar_mask,
-    summed_logits,
+    seq_logproba,
     temperature,
     deterministic_synthesis,
     forbidden_tokens=None,
             model.masked_inplace_autoregression(
                 input=input,
                 ar_mask=ar_mask,
-                summed_logits=summed_logits,
+                seq_logproba=seq_logproba,
                 temperature=temperature,
                 deterministic_synthesis=deterministic_synthesis,
                 forbidden_tokens=forbidden_tokens,
                 batch_size=self.batch_size,
                 input=result,
                 ar_mask=ar_mask,
-                summed_logits=None,
+                seq_logproba=None,
                 temperature=1.0,
                 deterministic_synthesis=deterministic_synthesis,
                 progress_bar_desc=None,
             batch_size=self.batch_size,
             input=result,
             ar_mask=ar_mask,
-            summed_logits=None,
+            seq_logproba=None,
             temperature=1.0,
             deterministic_synthesis=deterministic_synthesis,
             progress_bar_desc=None,
         nb,
         model,
         other_models,
-        desired_average_logits=None,
+        min_ave_seq_logproba=None,
     ):
         ###############################################################
         # Generate quizzes with model
         )
 
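+        # mask of ones: every token of the c_quizzes is to be generated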
         ar_mask = torch.full(c_quizzes.size(), 1, device=self.device)
-        summed_logits = torch.empty(nb, device=self.device)
+        seq_logproba = torch.empty(nb, device=self.device)
 
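+        # sampling temperature and the signed step used to adjust it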
         temperature = 1
         d_temperature = 1
 
         while True:
-            summed_logits[...] = 0
+            seq_logproba[...] = 0
 
             masked_inplace_autoregression(
                 model=model,
                 batch_size=self.batch_size,
                 input=c_quizzes,
                 ar_mask=ar_mask,
-                summed_logits=summed_logits,
+                seq_logproba=seq_logproba,
                 temperature=temperature,
                 deterministic_synthesis=False,
                 progress_bar_desc="sampling c_quizzes",
                 device=self.device,
             )
 
-            average_logits = summed_logits.mean()
+            ave_seq_logproba = seq_logproba.mean()
 
-            logger(f"{average_logits=} {desired_average_logits=}")
+            logger(f"{ave_seq_logproba=} {min_ave_seq_logproba=}")
 
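+            # no reference yet (first epoch): keep this batch as is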
-            if desired_average_logits is None:
+            if min_ave_seq_logproba is None:
                 break
 
-            # Oh man that's ugly
+            # Crude temperature search: nudge the temperature, halving and
+            # reversing the step whenever ave_seq_logproba overshoots, to
+            # keep it close to the reference min_ave_seq_logproba
-            if average_logits < desired_average_logits * 1.1:
+            if ave_seq_logproba < min_ave_seq_logproba * 1.1:
                 if d_temperature > 0:
                     d_temperature *= -0.5
                 temperature += d_temperature
-            elif average_logits > desired_average_logits:
+            elif ave_seq_logproba > min_ave_seq_logproba:
                 if d_temperature < 0:
                     d_temperature *= -0.5
                 temperature += d_temperature
                 batch_size=self.batch_size,
                 input=result,
                 ar_mask=ar_mask,
-                summed_logits=None,
+                seq_logproba=None,
                 temperature=1.0,
                 deterministic_synthesis=True,
                 progress_bar_desc="solving c_quizzes",
                 batch_size=self.batch_size,
                 input=reverse_result,
                 ar_mask=ar_mask,
-                summed_logits=None,
+                seq_logproba=None,
                 temperature=1.0,
                 deterministic_synthesis=True,
                 progress_bar_desc="solving reversed c_quizzes",
 
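+        # for each c_quiz, the number of other models that solved it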
         nb_correct = torch.cat(nb_correct, dim=0).sum(dim=0)
 
-        # filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat")
-        # with open(filename, "w") as f:
-        # for k in nb_correct:
-        # f.write(f"{k}\n")
-
-        return c_quizzes, nb_correct, summed_logits.mean()
+        return c_quizzes, nb_correct, ave_seq_logproba