From 688f96809d9a64338962264c0afcd86378183ff3 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 11 Aug 2024 10:25:49 +0200 Subject: [PATCH] Update. --- grids.py | 4 +++- main.py | 15 ++++----------- quiz_machine.py | 35 +++++++++++++++++++---------------- 3 files changed, 26 insertions(+), 28 deletions(-) diff --git a/grids.py b/grids.py index 1a31a36..0564f3b 100755 --- a/grids.py +++ b/grids.py @@ -218,7 +218,9 @@ class Grids(problem.Problem): dim=1 ).values - def make_ar_mask(self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)): + def make_quiz_mask( + self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1) + ): assert self.check_structure(quizzes, struct) ar_mask = quizzes.new_zeros(quizzes.size()) diff --git a/main.py b/main.py index f4691cb..a1389c1 100755 --- a/main.py +++ b/main.py @@ -502,18 +502,11 @@ def model_transformer_cold(model): # pass -warnings.warn("*********** novel procedure!!! **********", RuntimeWarning) - c_quizzes_procedure = [ - # (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot), - # (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold), - # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold), (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot), - (("f_B", "f_A", "A", "B"), (0, 1, 1, 0), model_transformer_cold), - (("A", "f_A", "B", "f_B"), (0, 0, 1, 1), model_transformer_cold), - # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_transformer_cold), - # (("f_B", "f_A", "A", "B"), (0, 0, 1, 1), model_transformer_cold), - # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold), + (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold), + (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold), + (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_transformer_cold), ] ###################################################################### @@ -768,7 +761,7 @@ def generate_c_quizzes_with_generator(generator, quiz_machine, nb): struct = ("A", "f_A", "B", "f_B") c_quizzes = quiz_machine.problem.create_empty_quizzes(nb, struct=struct) - ar_mask = quiz_machine.make_ar_mask(c_quizzes, struct, (1, 1, 1, 1)) + ar_mask = quiz_machine.make_quiz_mask(c_quizzes, struct, (1, 1, 1, 1)) i = F.one_hot( torch.randint(args.nb_gpts, (c_quizzes.size(0),)), diff --git a/quiz_machine.py b/quiz_machine.py index 34abd34..ceb523d 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -82,19 +82,20 @@ class QuizMachine: self.prompt_noise = prompt_noise # struct, mask_generate, mask_noise, mask_loss - self.understood_structures = [ - (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 0)), - (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)), - (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 0)), - (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)), + self.train_structures = [ + (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)), + (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)), + (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)), + (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)), (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)), + # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 0)), + # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)), + # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 0)), + # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)), + # (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)), ] - self.test_structures = [ - self.understood_structures[0], - self.understood_structures[2], - self.understood_structures[4], - ] + self.test_structures = self.train_structures self.LOCK_C_QUIZZES = threading.Lock() self.train_c_quizzes = [] @@ -185,13 +186,13 @@ class QuizMachine: quizzes, from_w = quizzes[i], from_w[i] self.randomize_configuations_inplace( - quizzes, structs=[s for s, _, _, _ in self.understood_structures] + quizzes, structs=[s for s, _, _, _ in self.train_structures] ) quiz_mask_loss = quizzes.new_full(quizzes.size(), 1) if self.prompt_noise > 0.0: - for struct, _, mask_noise, mask_loss in self.understood_structures: + for struct, _, mask_noise, mask_loss in self.train_structures: i = self.problem.indices_select(quizzes=quizzes, struct=struct) if i.any(): quizzes[i] = self.problem.inject_noise( @@ -206,7 +207,7 @@ class QuizMachine: ###################################################################### def make_ar_mask(self, quizzes, struct, mask): - assert struct in [s for s, _, _, _ in self.understood_structures] + assert struct in [s for s, _, _, _ in self.train_structures] return self.problem.make_ar_mask(quizzes, struct=struct, mask=mask) ###################################################################### @@ -343,7 +344,7 @@ class QuizMachine: models_for_validation, c_quizzes, struct, - mask_value, + mask_loss, mask_noise=None, device=None, ): @@ -373,13 +374,15 @@ class QuizMachine: seq_logproba.split(self.batch_size), ): input = input.to(device) - ar_mask = self.make_ar_mask(input, struct=struct, mask=mask_value) + quiz_mask_loss = self.make_ar_mask( + input, struct=struct, mask=mask_loss + ) output = model(mygpt.BracketedSequence(input)).x l[:, model.id] = ( -F.cross_entropy( output.transpose(1, 2), input, reduction="none" ) - * ar_mask + * quiz_mask_loss ).sum(dim=1) model.train(t) -- 2.39.5