From 60d829ba77c9769009d3d5a93a50d23c532d019a Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 25 Jun 2024 11:51:04 +0200 Subject: [PATCH] Update. --- main.py | 20 ++++++++++---------- mygpt.py | 10 +++++----- tasks.py | 37 ++++++++++++++++--------------------- 3 files changed, 31 insertions(+), 36 deletions(-) diff --git a/main.py b/main.py index ebecad8..2c759ec 100755 --- a/main.py +++ b/main.py @@ -350,7 +350,7 @@ def create_c_quizzes( task, nb_for_train=1000, nb_for_test=100, - desired_average_logits=None, + min_ave_seq_logproba=None, ): kept = [] @@ -359,17 +359,17 @@ def create_c_quizzes( while sum([x.size(0) for x in kept]) < nb_for_train + nb_for_test: nb_to_generate = 4 * (nb_for_train + nb_for_test) - new_c_quizzes, nb_correct, average_logits = task.create_c_quizzes( + new_c_quizzes, nb_correct, ave_seq_logproba = task.create_c_quizzes( n_epoch=n_epoch, result_dir=args.result_dir, logger=log_string, nb=nb_to_generate, model=model, other_models=other_models, - desired_average_logits=desired_average_logits, + min_ave_seq_logproba=min_ave_seq_logproba, ) - sum_logits += new_c_quizzes.size(0) * average_logits + sum_logits += new_c_quizzes.size(0) * ave_seq_logproba sum_nb_c_quizzes += new_c_quizzes.size(0) to_keep = new_c_quizzes[nb_correct == len(other_models) - 1] @@ -425,7 +425,7 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)") ###################################################################### -desired_average_logits = None +min_ave_seq_logproba = None for n_epoch in range(args.nb_epochs): log_string(f"--- epoch {n_epoch} ----------------------------------------") @@ -462,21 +462,21 @@ for n_epoch in range(args.nb_epochs): other_models = models.copy() other_models.remove(model) - average_logits = create_c_quizzes( + ave_seq_logproba = create_c_quizzes( model, other_models, task, nb_for_train=nb_new_c_quizzes_for_train, nb_for_test=nb_new_c_quizzes_for_test, - desired_average_logits=desired_average_logits, + min_ave_seq_logproba=min_ave_seq_logproba, ) # We keep the first average logits as a reference - if desired_average_logits is None: - desired_average_logits = average_logits + if min_ave_seq_logproba is None: + min_ave_seq_logproba = ave_seq_logproba else: log_string( - f"desired_average_logits {desired_average_logits} average_logits {average_logits}" + f"min_ave_seq_logproba {min_ave_seq_logproba} ave_seq_logproba {ave_seq_logproba}" ) # We update everyone diff --git a/mygpt.py b/mygpt.py index ab4ccbc..809f790 100755 --- a/mygpt.py +++ b/mygpt.py @@ -279,7 +279,7 @@ class MyGPT(nn.Module): self, input, ar_mask, - summed_logits, + seq_logproba, temperature=1.0, deterministic_synthesis=False, forbidden_tokens=None, @@ -309,10 +309,10 @@ class MyGPT(nn.Module): else: dist = torch.distributions.categorical.Categorical(logits=logits) t_next = dist.sample() - if summed_logits is not None: - summed_logits += logits[torch.arange(t_next.size(0)), t_next].sum( - dim=-1 - ) + + if seq_logproba is not None: + all_t = torch.arange(t_next.size(0)) + seq_logproba += logits[all_t, t_next].sum(dim=-1) input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s] diff --git a/tasks.py b/tasks.py index 43f7d53..a522728 100755 --- a/tasks.py +++ b/tasks.py @@ -22,7 +22,7 @@ def masked_inplace_autoregression( batch_size, input, ar_mask, - summed_logits, + seq_logproba, temperature, deterministic_synthesis, forbidden_tokens=None, @@ -50,7 +50,7 @@ def masked_inplace_autoregression( model.masked_inplace_autoregression( input=input, ar_mask=ar_mask, - summed_logits=summed_logits, + seq_logproba=seq_logproba, temperature=temperature, deterministic_synthesis=deterministic_synthesis, forbidden_tokens=forbidden_tokens, @@ -184,7 +184,7 @@ class World(Task): batch_size=self.batch_size, input=result, ar_mask=ar_mask, - summed_logits=None, + seq_logproba=None, temperature=1.0, deterministic_synthesis=deterministic_synthesis, progress_bar_desc=None, @@ -224,7 +224,7 @@ class World(Task): batch_size=self.batch_size, input=result, ar_mask=ar_mask, - summed_logits=None, + seq_logproba=None, temperature=1.0, deterministic_synthesis=deterministic_synthesis, progress_bar_desc=None, @@ -262,7 +262,7 @@ class World(Task): nb, model, other_models, - desired_average_logits=None, + min_ave_seq_logproba=None, ): ############################################################### # Generate quizzes with model @@ -272,39 +272,39 @@ class World(Task): ) ar_mask = torch.full(c_quizzes.size(), 1, device=self.device) - summed_logits = torch.empty(nb, device=self.device) + seq_logproba = torch.empty(nb, device=self.device) temperature = 1 d_temperature = 1 while True: - summed_logits[...] = 0 + seq_logproba[...] = 0 masked_inplace_autoregression( model=model, batch_size=self.batch_size, input=c_quizzes, ar_mask=ar_mask, - summed_logits=summed_logits, + seq_logproba=seq_logproba, temperature=temperature, deterministic_synthesis=False, progress_bar_desc="sampling c_quizzes", device=self.device, ) - average_logits = summed_logits.mean() + ave_seq_logproba = seq_logproba.mean() - logger(f"{average_logits=} {desired_average_logits=}") + logger(f"{ave_seq_logproba=} {min_ave_seq_logproba=}") - if desired_average_logits is None: + if min_ave_seq_logproba is None: break # Oh man that's ugly - if average_logits < desired_average_logits * 1.1: + if ave_seq_logproba < min_ave_seq_logproba * 1.1: if d_temperature > 0: d_temperature *= -0.5 temperature += d_temperature - elif average_logits > desired_average_logits: + elif ave_seq_logproba > min_ave_seq_logproba: if d_temperature < 0: d_temperature *= -0.5 temperature += d_temperature @@ -341,7 +341,7 @@ class World(Task): batch_size=self.batch_size, input=result, ar_mask=ar_mask, - summed_logits=None, + seq_logproba=None, temperature=1.0, deterministic_synthesis=True, progress_bar_desc="solving c_quizzes", @@ -357,7 +357,7 @@ class World(Task): batch_size=self.batch_size, input=reverse_result, ar_mask=ar_mask, - summed_logits=None, + seq_logproba=None, temperature=1.0, deterministic_synthesis=True, progress_bar_desc="solving reversed c_quizzes", @@ -372,9 +372,4 @@ class World(Task): nb_correct = torch.cat(nb_correct, dim=0).sum(dim=0) - # filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat") - # with open(filename, "w") as f: - # for k in nb_correct: - # f.write(f"{k}\n") - - return c_quizzes, nb_correct, summed_logits.mean() + return c_quizzes, nb_correct, seq_logproba.mean() -- 2.20.1