From 1e09e9362c049054f25588bc699d761af5b715c9 Mon Sep 17 00:00:00 2001
From: François Fleuret
Date: Sun, 11 Aug 2024 00:58:46 +0200
Subject: [PATCH] Mask the train and test losses per token

data_input() now returns a per-token loss mask instead of the from_w
flags: the structures carry a fourth mask_loss field, the quadrant
that receives prompt noise is excluded from the loss, run_tests() and
one_epoch() weight the cross-entropy with the mask, the hard-quiz
bookkeeping is dropped, and prediction accuracy is measured on the
dedicated test_structures subset.
---
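Notes (not part of the commit, ignored by git am): the core change is
that data_input() now returns a per-token loss mask along with the
quizzes, and both run_tests() and one_epoch() weight the token-wise
cross-entropy with it instead of averaging over every position. A
minimal sketch of that masked loss, with hypothetical sizes (batch N,
sequence length T, vocabulary V):

    import torch
    import torch.nn.functional as F

    N, T, V = 4, 8, 16                    # hypothetical sizes
    logits = torch.randn(N, T, V)         # stand-in for the model output
    targets = torch.randint(V, (N, T))    # token ids to predict
    mask_loss = torch.randint(2, (N, T))  # 1 = token counts in the loss

    # cross_entropy wants the class dimension second, hence the
    # transpose; reduction="none" keeps one loss value per token
    loss_per_token = F.cross_entropy(
        logits.transpose(1, 2), targets, reduction="none"
    )

    # masked-out tokens contribute zero, then average over all
    # positions, as the patch does with .mean()
    loss = (loss_per_token * mask_loss).mean()

Each structure gains a fourth tuple, mask_loss, e.g. (1, 1, 0, 1) to
exclude the third quadrant, the one that gets prompt noise under
mask_noise (0, 0, 1, 0), from the loss. Assuming (this is not shown in
the patch) that make_ar_mask() expands such a quadrant-level mask to
one value per token over four quadrants of equal length L:

    L = 3  # hypothetical quadrant length
    per_token = torch.tensor([m for m in (1, 1, 0, 1) for _ in range(L)])
    # -> tensor([1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1])
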
+ (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)), + (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)), + ] + + self.test_structures = [ + self.understood_structures[0], + self.understood_structures[2], + self.understood_structures[4], ] self.LOCK_C_QUIZZES = threading.Lock() @@ -179,23 +185,28 @@ class QuizMachine: quizzes, from_w = quizzes[i], from_w[i] self.randomize_configuations_inplace( - quizzes, structs=[s for s, _, _ in self.understood_structures] + quizzes, structs=[s for s, _, _, _ in self.understood_structures] ) + quiz_mask_loss = quizzes.new_full(quizzes.size(), 1) + if self.prompt_noise > 0.0: - for struct, _, mask_noise in self.understood_structures: + for struct, _, mask_noise, mask_loss in self.understood_structures: i = self.problem.indices_select(quizzes=quizzes, struct=struct) if i.any(): quizzes[i] = self.problem.inject_noise( quizzes[i], self.prompt_noise, struct=struct, mask=mask_noise ) + quiz_mask_loss[i] = self.make_ar_mask( + quizzes=quizzes[i], struct=struct, mask=mask_loss + ) - return quizzes, from_w + return quizzes, quiz_mask_loss ###################################################################### def make_ar_mask(self, quizzes, struct, mask): - assert struct in [s for s, _, _ in self.understood_structures] + assert struct in [s for s, _, _, _ in self.understood_structures] return self.problem.make_ar_mask(quizzes, struct=struct, mask=mask) ###################################################################### @@ -229,7 +240,7 @@ class QuizMachine: nb = 0 # We consider all the configurations that we train for - for struct, mask_generate, _ in self.understood_structures: + for struct, mask_generate, _, _ in self.test_structures: i = self.problem.indices_select(quizzes=input, struct=struct) nb += i.long().sum() result[i], correct[i] = self.predict( -- 2.39.5