From a3ba9c8985926df4bbcca26cdde51dc55aa4d442 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 16 Jul 2024 22:21:03 +0200 Subject: [PATCH] Update. --- main.py | 80 +++++++++++++++++-------------------------------- quiz_machine.py | 57 ++++++++++++++++++++--------------- 2 files changed, 60 insertions(+), 77 deletions(-) diff --git a/main.py b/main.py index 41efc86..ca1e9b5 100755 --- a/main.py +++ b/main.py @@ -314,7 +314,10 @@ def run_tests(model, quiz_machine, deterministic_synthesis, local_device=main_de nb_test_samples, acc_test_loss = 0, 0.0 nb_samples_accumulated = 0 - for input in quiz_machine.batches(model, split="test"): + full_input, _ = quiz_machine.data_input(model, split="test") + src = full_input.split(args.batch_size) + + for input in tqdm.tqdm(src, dynamic_ncols=True, desc="test"): input = input.to(local_device) bs = model(mygpt.BracketedSequence(input)) @@ -345,16 +348,29 @@ def one_epoch(model, quiz_machine, local_device=main_device): nb_train_samples, acc_train_loss = 0, 0.0 - for input in quiz_machine.batches(model, split="train"): + hard_w_quizzes = [] + + full_input, full_from_w = quiz_machine.data_input(model, split="train") + src = zip(full_input.split(args.batch_size), full_from_w.split(args.batch_size)) + + for input, from_w in tqdm.tqdm(src, dynamic_ncols=True, desc="training"): input = input.to(local_device) if nb_train_samples % args.batch_size == 0: optimizer.zero_grad() output = model(mygpt.BracketedSequence(input)).x - loss = F.cross_entropy(output.transpose(1, 2), input) + loss_per_token = F.cross_entropy( + output.transpose(1, 2), input, reduction="none" + ) + loss = loss_per_token.mean() acc_train_loss += loss.item() * input.size(0) + loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1) + hard_w_quizzes.append( + (input[from_w].to("cpu"), loss_per_samples[from_w].to("cpu")) + ) + nb_train_samples += input.size(0) loss.backward() @@ -368,6 +384,13 @@ def one_epoch(model, quiz_machine, local_device=main_device): run_tests(model, quiz_machine, deterministic_synthesis=False) + threshold = torch.cat([x[1] for x in hard_w_quizzes], dim=0).sort().values + threshold = threshold[threshold.size(0) // 2] + + model.hard_w_quizzes = torch.cat( + [x[0][x[1] >= threshold] for x in hard_w_quizzes], dim=0 + ) + model.to(main_device) @@ -443,7 +466,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 e = "???" log_string( - f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_create} (finishes {e} -- {(total_nb_validated * 3600)/duration:0.1f}/h)" + f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_create} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h)" ) validated_quizzes = torch.cat(recorded, dim=0) @@ -542,54 +565,6 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)") ###################################################################### -# Compute the entropy of the training tokens - -token_count = 0 -for input in quiz_machine.batches(models[0], split="train", desc="train-entropy"): - token_count += F.one_hot(input, num_classes=quiz_machine.vocabulary_size()).sum( - (0, 1) - ) -token_probas = token_count / token_count.sum() -entropy = -torch.xlogy(token_probas, token_probas).sum() -train_set_perplexity = math.exp(entropy) - -###################################################################### -# A bit of paranoia never hurts - -if args.max_percents_of_test_in_train >= 0: - - def subsets_as_tuples(batches, cs): - s = set() - for batch in batches: - for x in batch: - s.add(tuple([v.item() for v in x])) - if len(s) == cs: - yield s - s = set() - yield s - - nb_test, nb_in_train = 0, 0 - for test_subset in subsets_as_tuples( - quiz_machine.batches(models[0], split="test", desc="test-check"), 25000 - ): - in_train = set() - for train_subset in subsets_as_tuples( - quiz_machine.batches(models[0], split="train", desc="train-check"), 25000 - ): - in_train.update(test_subset.intersection(train_subset)) - nb_in_train += len(in_train) - nb_test += len(test_subset) - - log_string( - f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set" - ) - - assert ( - nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100 - ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set" - -###################################################################### - if args.nb_new_c_quizzes_for_train is None: args.nb_new_c_quizzes_for_train = args.nb_train_samples // 100 @@ -679,7 +654,6 @@ for n_epoch in range(args.nb_epochs): for model in weakest_models: quiz_machine.renew_w_quizzes( model=model, - nb=args.nb_train_samples, for_train=True, forward_only=args.forward_only, ) diff --git a/quiz_machine.py b/quiz_machine.py index faa640e..32b3f7e 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -313,7 +313,7 @@ class QuizMachine: ###################################################################### - def batches(self, model, split="train", desc=None): + def data_input(self, model, split="train"): assert split in {"train", "test"} with self.LOCK_C_QUIZZES: @@ -335,24 +335,18 @@ class QuizMachine: ] w_quizzes = w_quizzes[i] - self.nb_batch_w_quizzes = w_quizzes.size(0) - self.nb_batch_c_quizzes = c_quizzes.size(0) + quizzes = torch.cat([w_quizzes, c_quizzes], dim=0) + from_w = torch.arange( + quizzes.size(0), device=quizzes.device + ) < w_quizzes.size(0) + i = torch.randperm(quizzes.size(0), device=quizzes.device) - input = torch.cat([w_quizzes, c_quizzes], dim=0) - else: - input = w_quizzes - self.nb_batch_w_quizzes = w_quizzes.size(0) - self.nb_batch_c_quizzes = 0 - - # Shuffle - input = input[torch.randperm(input.size(0))] + return quizzes[i], type_w[i] - if desc is None: - desc = f"epoch-{split}" - for batch in tqdm.tqdm( - input.split(self.batch_size), dynamic_ncols=True, desc=desc - ): - yield batch + else: + return w_quizzes, torch.full( + (w_quizzes.size(0),), True, device=w_quizzes.device + ) ###################################################################### @@ -441,14 +435,29 @@ class QuizMachine: ###################################################################### - def renew_w_quizzes(self, model, nb, for_train=True, forward_only=False): + def renew_w_quizzes(self, model, for_train=True, forward_only=False): input = model.train_w_quizzes if for_train else model.test_w_quizzes - nb = min(nb, input.size(0)) - input[:-nb] = input[nb:].clone() - fresh_w_quizzes = self.generate_token_sequences(nb) - if not forward_only: - self.reverse_random_half_in_place(fresh_w_quizzes) - input[-nb:] = fresh_w_quizzes.to("cpu") + + if for_train and hasattr(model, "hard_w_quizzes"): + self.logger( + f"re-using {model.hard_w_quizzes.size(0)} hard world quizzes from model {model.id}" + ) + if model.hard_w_quizzes.size(0) >= input.size(0): + input[...] = model.hard_w_quizzes[ + torch.randperm(hard_w_quizzes.size(0))[input.size(0)] + ] + else: + input[...] = torch.cat( + [ + model.hard_w_quizzes, + self.generate_token_sequences( + input.size(0) - model.hard_w_quizzes.size(0) + ), + ], + dim=0, + ) + else: + input[...] = self.generate_token_sequences(input.size(0)) ###################################################################### -- 2.39.5