From 35a16ac34a3f1af05323a9cb3823fbcfd74035a4 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 24 Jun 2024 21:23:26 +0200 Subject: [PATCH] Update. --- main.py | 24 +++++++++++++++--------- mygpt.py | 16 +++++++--------- tasks.py | 28 ++++++++++++++++++---------- 3 files changed, 40 insertions(+), 28 deletions(-) diff --git a/main.py b/main.py index ee4e9e5..8033836 100755 --- a/main.py +++ b/main.py @@ -73,7 +73,7 @@ parser.add_argument("--deterministic_synthesis", action="store_true", default=Fa parser.add_argument("--nb_gpts", type=int, default=5) -parser.add_argument("--check", action="store_true", default=False) +parser.add_argument("--dirty_debug", action="store_true", default=False) ###################################################################### @@ -182,9 +182,9 @@ for n in vars(args): ###################################################################### -if args.check: - args.nb_train_samples = 25000 - args.nb_test_samples = 1000 +if args.dirty_debug: + args.nb_train_samples = 2500 + args.nb_test_samples = 100 if args.physical_batch_size is None: args.physical_batch_size = args.batch_size @@ -339,12 +339,12 @@ def create_quizzes( ): kept = [] - sum_logits = 0 + sum_logits, sum_nb_quizzes = 0, 0 while sum([x.size(0) for x in kept]) < nb_for_train + nb_for_test: nb_to_generate = 4 * (nb_for_train + nb_for_test) - new_quizzes, nb_correct, _sum_logits = task.create_new_quizzes( + new_quizzes, nb_correct, average_logits = task.create_new_quizzes( n_epoch=n_epoch, result_dir=args.result_dir, logger=log_string, @@ -354,12 +354,18 @@ def create_quizzes( desired_average_logits=desired_average_logits, ) - sum_logits += _sum_logits + sum_logits += new_quizzes.size(0) * average_logits + sum_nb_quizzes += new_quizzes.size(0) to_keep = new_quizzes[nb_correct == len(other_models) - 1] + + if args.dirty_debug: + to_keep = new_quizzes + log_string( f"keep {to_keep.size(0)}/{new_quizzes.size(0)} quizzes ({to_keep.size(0)*100/new_quizzes.size(0):.02f}%)" ) + kept.append(to_keep) new_quizzes = torch.cat(kept, dim=0)[: nb_for_train + nb_for_test] @@ -374,7 +380,7 @@ def create_quizzes( log_string, ) - return sum_logits / new_quizzes.size(0) + return sum_logits / sum_nb_quizzes ###################################################################### @@ -408,7 +414,7 @@ accuracy_to_make_quizzes = 0.975 nb_new_quizzes_for_train = 1000 nb_new_quizzes_for_test = 100 -if args.check: +if args.dirty_debug: accuracy_to_make_quizzes = 0.0 nb_new_quizzes_for_train = 100 nb_new_quizzes_for_test = 10 diff --git a/mygpt.py b/mygpt.py index 3e63567..ab4ccbc 100755 --- a/mygpt.py +++ b/mygpt.py @@ -279,13 +279,12 @@ class MyGPT(nn.Module): self, input, ar_mask, + summed_logits, temperature=1.0, deterministic_synthesis=False, forbidden_tokens=None, forced_biases=None, ): - sum_logits = 0 - to_generate = (ar_mask.sum(0) > 0).nonzero() if to_generate.min() > 0: @@ -297,7 +296,7 @@ class MyGPT(nn.Module): logits = output[:, s] - logits = logits.log_softmax(dim=1) / temperature + logits = (logits / temperature).log_softmax(dim=-1) if forbidden_tokens is not None: logits = logits.masked_fill(forbidden_tokens, float("-inf")) @@ -306,18 +305,17 @@ class MyGPT(nn.Module): logits = logits + forced_biases[None, :] if deterministic_synthesis: - t_next = logits.argmax(1) + t_next = logits.argmax(-1) else: dist = torch.distributions.categorical.Categorical(logits=logits) t_next = dist.sample() - sum_logits += logits.log_softmax(dim=1)[ - torch.arange(t_next.size(0)), t_next - ].sum() + if summed_logits is not None: + summed_logits += logits[torch.arange(t_next.size(0)), t_next].sum( + dim=-1 + ) input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s] - return sum_logits - def record_attention(self, v=True): for m in self.modules(): if isinstance(m, QKVAttention): diff --git a/tasks.py b/tasks.py index 2a1833d..39372f3 100755 --- a/tasks.py +++ b/tasks.py @@ -22,6 +22,7 @@ def masked_inplace_autoregression( batch_size, input, ar_mask, + summed_logits, temperature, deterministic_synthesis, forbidden_tokens=None, @@ -41,16 +42,15 @@ def masked_inplace_autoregression( total=(input.size(0) + batch_size - 1) // batch_size, ) - sum_logits = 0 - with torch.autograd.no_grad(): t = model.training model.eval() for input, ar_mask in batches: - sum_logits += model.masked_inplace_autoregression( + model.masked_inplace_autoregression( input=input, ar_mask=ar_mask, + summed_logits=summed_logits, temperature=temperature, deterministic_synthesis=deterministic_synthesis, forbidden_tokens=forbidden_tokens, @@ -59,8 +59,6 @@ def masked_inplace_autoregression( model.train(t) - return sum_logits - ###################################################################### @@ -180,6 +178,7 @@ class World(Task): batch_size=self.batch_size, input=result, ar_mask=ar_mask, + summed_logits=None, temperature=1.0, deterministic_synthesis=deterministic_synthesis, progress_bar_desc=None, @@ -219,6 +218,7 @@ class World(Task): batch_size=self.batch_size, input=result, ar_mask=ar_mask, + summed_logits=None, temperature=1.0, deterministic_synthesis=deterministic_synthesis, progress_bar_desc=None, @@ -266,23 +266,27 @@ class World(Task): ) ar_mask = torch.full(quizzes.size(), 1, device=self.device) + summed_logits = torch.empty(nb, device=self.device) temperature = 1 d_temperature = 1 while True: - sum_logits = masked_inplace_autoregression( + summed_logits[...] = 0 + + masked_inplace_autoregression( model=model, batch_size=self.batch_size, input=quizzes, ar_mask=ar_mask, + summed_logits=summed_logits, temperature=temperature, deterministic_synthesis=False, progress_bar_desc="creating quizzes", device=self.device, ) - average_logits = sum_logits / quizzes.size(0) + average_logits = summed_logits.mean() logger(f"{average_logits=} {desired_average_logits=}") @@ -290,14 +294,16 @@ class World(Task): break # Oh man that's ugly - if average_logits > desired_average_logits: + if average_logits < desired_average_logits: if d_temperature < 0: d_temperature *= -0.5 temperature += d_temperature - else: + elif average_logits > desired_average_logits * 0.95: if d_temperature > 0: d_temperature *= -0.5 temperature += d_temperature + else: + break logger(f"chaging temperature to {temperature}") @@ -329,6 +335,7 @@ class World(Task): batch_size=self.batch_size, input=result, ar_mask=ar_mask, + summed_logits=None, temperature=1.0, deterministic_synthesis=True, progress_bar_desc="solving quizzes", @@ -344,6 +351,7 @@ class World(Task): batch_size=self.batch_size, input=reverse_result, ar_mask=ar_mask, + summed_logits=None, temperature=1.0, deterministic_synthesis=True, progress_bar_desc="solving reversed quizzes", @@ -363,4 +371,4 @@ class World(Task): # for k in nb_correct: # f.write(f"{k}\n") - return quizzes, nb_correct.sum(dim=0), sum_logits + return quizzes, nb_correct.sum(dim=0), summed_logits.mean() -- 2.20.1