From b3392c295bdb75140916e2db70efc6fa50962f63 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Mon, 24 Jun 2024 20:38:07 +0200
Subject: [PATCH] Update.

---
 main.py  | 17 +++++++++--------
 mygpt.py |  8 ++++++--
 tasks.py | 61 +++++++++++++++++++++++++++++++-----------------------------
 3 files changed, 46 insertions(+), 40 deletions(-)

diff --git a/main.py b/main.py
index 2afe61b..ee4e9e5 100755
--- a/main.py
+++ b/main.py
@@ -183,8 +183,8 @@ for n in vars(args):
 ######################################################################
 
 if args.check:
-    args.nb_train_samples = 2500
-    args.nb_test_samples = 100
+    args.nb_train_samples = 25000
+    args.nb_test_samples = 1000
 
 if args.physical_batch_size is None:
     args.physical_batch_size = args.batch_size
@@ -338,11 +338,13 @@ def create_quizzes(
     desired_average_logits=None,
 ):
     kept = []
-    nb_generated_tokens, sum_logits = 0, 0
+
+    sum_logits = 0
 
     while sum([x.size(0) for x in kept]) < nb_for_train + nb_for_test:
         nb_to_generate = 4 * (nb_for_train + nb_for_test)
-        new_quizzes, nb_correct, average_logits = task.create_new_quizzes(
+
+        new_quizzes, nb_correct, _sum_logits = task.create_new_quizzes(
             n_epoch=n_epoch,
             result_dir=args.result_dir,
             logger=log_string,
@@ -352,8 +354,7 @@ def create_quizzes(
             desired_average_logits=desired_average_logits,
         )
 
-        nb_generated_tokens += new_quizzes.numel()
-        sum_logits += average_logits * new_quizzes.numel()
+        sum_logits += _sum_logits
 
         to_keep = new_quizzes[nb_correct == len(other_models) - 1]
 
         log_string(
@@ -373,7 +374,7 @@ def create_quizzes(
         log_string,
     )
 
-    return sum_logits / nb_generated_tokens
+    return sum_logits / new_quizzes.size(0)
 
 
 ######################################################################
@@ -409,7 +410,7 @@ nb_new_quizzes_for_test = 100
 
 if args.check:
     accuracy_to_make_quizzes = 0.0
-    nb_new_quizzes_for_train = 10
+    nb_new_quizzes_for_train = 100
     nb_new_quizzes_for_test = 10
 
 desired_average_logits = None
diff --git a/mygpt.py b/mygpt.py
index c58bea1..3e63567 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -285,16 +285,19 @@ class MyGPT(nn.Module):
         forced_biases=None,
     ):
         sum_logits = 0
+
         to_generate = (ar_mask.sum(0) > 0).nonzero()
+
         if to_generate.min() > 0:
             self(
                 BracketedSequence(input, 0, to_generate.min())
             )  # Needed to initialize the model's cache
 
         for s in range(to_generate.min(), to_generate.max() + 1):
             output = self(BracketedSequence(input, s, 1)).x
+
             logits = output[:, s]
-            logits = logits.log_softmax(dim=-1) / temperature
+            logits = logits.log_softmax(dim=1) / temperature
 
             if forbidden_tokens is not None:
                 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
@@ -307,9 +310,10 @@ class MyGPT(nn.Module):
             else:
                 dist = torch.distributions.categorical.Categorical(logits=logits)
                 t_next = dist.sample()
-                sum_logits += logits.log_softmax(dim=-1)[
+                sum_logits += logits.log_softmax(dim=1)[
                     torch.arange(t_next.size(0)), t_next
                 ].sum()
+
             input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
 
         return sum_logits
diff --git a/tasks.py b/tasks.py
index 64fe967..b3b56ad 100755
--- a/tasks.py
+++ b/tasks.py
@@ -41,12 +41,12 @@ def masked_inplace_autoregression(
             total=(input.size(0) + batch_size - 1) // batch_size,
         )
 
+    sum_logits = 0
+
    with torch.autograd.no_grad():
         t = model.training
         model.eval()
 
-        sum_logits = 0
-
         for input, ar_mask in batches:
             sum_logits += model.masked_inplace_autoregression(
                 input=input,
@@ -59,7 +59,7 @@
 
         model.train(t)
 
-        return sum_logits
+    return sum_logits
 
 
 ######################################################################
@@ -264,31 +264,14 @@ class World(Task):
         quizzes = torch.empty(
             nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
         )
-        ar_mask = torch.full(quizzes.size(), 1, device=self.device)
-        sum_logits = masked_inplace_autoregression(
-            model=model,
-            batch_size=self.batch_size,
-            input=quizzes,
-            ar_mask=ar_mask,
-            temperature=1.0,
-            deterministic_synthesis=False,
-            progress_bar_desc="creating quizzes",
-            device=self.device,
-        )
-
-        # Should not be necessary though, the autoregression is done
-        # in eval mode
-        sum_logits = sum_logits.detach()
-
-        average_logits = sum_logits / quizzes.numel()
+        ar_mask = torch.full(quizzes.size(), 1, device=self.device)
 
-        # It's a bit brutal to do it twice, we should probably have a
-        # moving average and apply it right away
+        temperature = 1
+        d_temperature = 1
 
-        if desired_average_logits is not None:
-            temperature = average_logits / desired_average_logits
-            masked_inplace_autoregression(
+        while True:
+            sum_logits = masked_inplace_autoregression(
                 model=model,
                 batch_size=self.batch_size,
                 input=quizzes,
                 ar_mask=ar_mask,
                 temperature=temperature,
                 deterministic_synthesis=False,
                 progress_bar_desc="creating quizzes",
                 device=self.device,
             )
@@ -299,6 +282,24 @@
+            average_logits = sum_logits / quizzes.size(0)
+
+            logger(f"{average_logits=} {desired_average_logits=}")
+
+            if desired_average_logits is None:
+                break
+
+            # Oh man that's ugly
+            if average_logits > desired_average_logits:
+                if d_temperature < 0:
+                    d_temperature *= -0.5
+                temperature += d_temperature
+            else:
+                if d_temperature > 0:
+                    d_temperature *= -0.5
+                temperature += d_temperature
+            logger(f"changing temperature to {temperature}")
 
         ###############################################################
         # Create the reverse quizzes
 
@@ -356,9 +357,9 @@ class World(Task):
 
         nb_correct = torch.cat(nb_correct, dim=0)
 
-        filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat")
-        with open(filename, "w") as f:
-            for k in nb_correct:
-                f.write(f"{k}\n")
+        # filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat")
+        # with open(filename, "w") as f:
+        #     for k in nb_correct:
+        #         f.write(f"{k}\n")
 
-        return quizzes, nb_correct.sum(dim=0), average_logits
+        return quizzes, nb_correct.sum(dim=0), sum_logits
-- 
2.39.5
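
A note on the two mechanisms this patch introduces, with standalone sketches
(nothing below is part of the patch itself).

The loop added to World.create_new_quizzes searches for the sampling
temperature at which the average log-probability of the generated quizzes
matches desired_average_logits: each time the average crosses the target, the
step d_temperature is reversed and halved, so the search behaves like a crude
bisection. A minimal sketch of the same scheme, where the toy function
average_logits_at stands in for the generation pass and a tolerance test
stands in for the desired_average_logits-is-None exit:

    def adjust_temperature(average_logits_at, desired_average_logits):
        temperature, d_temperature = 1.0, 1.0
        while True:
            average_logits = average_logits_at(temperature)
            if abs(average_logits - desired_average_logits) < 1e-3:
                return temperature
            if average_logits > desired_average_logits:
                # Too high: move the temperature up; if the last step moved
                # it down, reverse the direction and halve the step.
                if d_temperature < 0:
                    d_temperature *= -0.5
                temperature += d_temperature
            else:
                # Too low: move the temperature down, same halving rule.
                if d_temperature > 0:
                    d_temperature *= -0.5
                temperature += d_temperature

    # A hotter distribution yields less likely samples, hence a toy stand-in
    # that decreases with temperature; this converges to 2.5.
    print(adjust_temperature(lambda t: -2.0 * t, desired_average_logits=-5.0))

The sum_logits bookkeeping in MyGPT.masked_inplace_autoregression relies on a
standard PyTorch indexing idiom: after sampling one token per sequence, the
row-wise log_softmax is indexed with (torch.arange(batch), t_next) to pick
out the log-probability each sequence assigned to the token it actually drew.
Simplified (in the patch the logits have already been tempered at this point):

    import torch

    logits = torch.randn(4, 10)  # one row of next-token logits per sequence
    dist = torch.distributions.categorical.Categorical(logits=logits)
    t_next = dist.sample()  # one sampled token index per sequence
    sum_logits = logits.log_softmax(dim=1)[
        torch.arange(t_next.size(0)), t_next
    ].sum()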