From 4de6ddc3d455dcd98e5472c29ed8308d60fc9113 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 25 Jul 2024 06:01:29 +0200 Subject: [PATCH] Update. --- grids.py | 2 +- quiz_machine.py | 28 ++++++++++++++++------------ 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/grids.py b/grids.py index 25bbc80..f6129e9 100755 --- a/grids.py +++ b/grids.py @@ -139,7 +139,7 @@ class Grids(problem.Problem): def reconfigure(self, quizzes, struct=("A", "f_A", "B", "f_B")): if torch.is_tensor(quizzes): - return self.reconfigure([quizzes])[0] + return self.reconfigure([quizzes], struct=struct)[0] S = self.height * self.width result = [x.new(x.size()) for x in quizzes] diff --git a/quiz_machine.py b/quiz_machine.py index 8f14fa0..4615e3a 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -131,7 +131,7 @@ class QuizMachine: self.prompt_len = None self.answer_len = None - self.configurations = [ + self.train_struct = [ ("A", "f_A", "B", "f_B"), # The standard order ("f_A", "A", "f_B", "B"), # The reverse order for validation ("f_B", "f_A", "A", "B"), # The synthesis order @@ -183,8 +183,12 @@ class QuizMachine: ###################################################################### + def make_ar_mask(self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)): + assert struct in self.train_struct + return self.problem.make_ar_mask(quizzes, struct, mask) + def predict(self, model, quizzes, struct, mask): - ar_mask = self.problem.make_ar_mask(quizzes=quizzes, struct=struct, mask=mask) + ar_mask = self.make_ar_mask(quizzes=quizzes, struct=struct, mask=mask) result = quizzes * (1 - ar_mask) seq_logproba = torch.empty(quizzes.size(0), device=self.device) @@ -250,7 +254,7 @@ class QuizMachine: predicted_parts = predicted_parts[:128] correct_parts = correct_parts[:128] - self.problem.reconfigure( + result, predicted_parts, correct_parts = self.problem.reconfigure( [result, predicted_parts, correct_parts], ("A", "f_A", "B", "f_B") ) @@ -281,11 +285,11 @@ class QuizMachine: model.test_w_quizzes = self.problem.generate_w_quizzes(nb_test_samples) self.randomize_configuations_inplace( - model.train_w_quizzes, configurations=self.configurations + model.train_w_quizzes, configurations=self.train_struct ) self.randomize_configuations_inplace( - model.test_w_quizzes, configurations=self.configurations + model.test_w_quizzes, configurations=self.train_struct ) ###################################################################### @@ -318,7 +322,7 @@ class QuizMachine: ) self.randomize_configuations_inplace( - model.train_w_quizzes, configurations=self.configurations + model.train_w_quizzes, configurations=self.train_struct ) ###################################################################### @@ -356,7 +360,7 @@ class QuizMachine: c_quizzes.split(self.batch_size), logproba.split(self.batch_size) ): input = input.to(self.device) - ar_mask = self.problem.make_ar_mask(input, shape="fwd_3_bck_123") + ar_mask = self.make_ar_mask(input, shape="fwd_3_bck_123") output = model(mygpt.BracketedSequence(input)).x l[:, model.id] = ( -F.cross_entropy( @@ -397,7 +401,7 @@ class QuizMachine: # A, f(A), B | f(B) result = c_quizzes.clone() - ar_mask = self.problem.make_ar_mask( + ar_mask = self.make_ar_mask( result, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1) ) @@ -418,7 +422,7 @@ class QuizMachine: # f(A), A, f(B) | B result = reversed_c_quizzes.clone() - ar_mask = self.problem.make_ar_mask( + ar_mask = self.make_ar_mask( result, ("f_A", "A", "f_B", "B"), mask=(0, 0, 0, 1) ) @@ -462,7 +466,7 @@ class QuizMachine: model=model_for_generation, batch_size=self.batch_size, input=c_quizzes, - ar_mask=self.problem.make_ar_mask( + ar_mask=self.make_ar_mask( c_quizzes, ("f_B", "f_A", "A", "B"), (1, 0, 0, 0) ), seq_logproba=seq_logproba, @@ -475,7 +479,7 @@ class QuizMachine: model=model_for_generation, batch_size=self.batch_size, input=c_quizzes, - ar_mask=self.problem.make_ar_mask( + ar_mask=self.make_ar_mask( c_quizzes, ("f_B", "f_A", "A", "B"), (0, 1, 1, 1) ), seq_logproba=seq_logproba, @@ -490,7 +494,7 @@ class QuizMachine: model=model_for_generation, batch_size=self.batch_size, input=c_quizzes, - ar_mask=self.problem.make_ar_mask( + ar_mask=self.make_ar_mask( c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1) ), seq_logproba=seq_logproba, -- 2.20.1