From 0ffb70add2f70d6585ab871f835fa73e25c7a4ad Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 11 Aug 2024 10:26:07 +0200 Subject: [PATCH] Update. --- quiz_machine.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/quiz_machine.py b/quiz_machine.py index ceb523d..92da03d 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -198,7 +198,7 @@ class QuizMachine: quizzes[i] = self.problem.inject_noise( quizzes[i], self.prompt_noise, struct=struct, mask=mask_noise ) - quiz_mask_loss[i] = self.make_ar_mask( + quiz_mask_loss[i] = self.make_quiz_mask( quizzes=quizzes[i], struct=struct, mask=mask_loss ) @@ -206,14 +206,14 @@ class QuizMachine: ###################################################################### - def make_ar_mask(self, quizzes, struct, mask): + def make_quiz_mask(self, quizzes, struct, mask): assert struct in [s for s, _, _, _ in self.train_structures] - return self.problem.make_ar_mask(quizzes, struct=struct, mask=mask) + return self.problem.make_quiz_mask(quizzes, struct=struct, mask=mask) ###################################################################### def predict(self, model, quizzes, struct, mask): - ar_mask = self.make_ar_mask(quizzes=quizzes, struct=struct, mask=mask) + ar_mask = self.make_quiz_mask(quizzes=quizzes, struct=struct, mask=mask) result = quizzes * (1 - ar_mask) seq_logproba = torch.empty(quizzes.size(0), device=self.device) @@ -374,7 +374,7 @@ class QuizMachine: seq_logproba.split(self.batch_size), ): input = input.to(device) - quiz_mask_loss = self.make_ar_mask( + quiz_mask_loss = self.make_quiz_mask( input, struct=struct, mask=mask_loss ) output = model(mygpt.BracketedSequence(input)).x @@ -410,7 +410,7 @@ class QuizMachine: self.autoregression( model=model_for_generation, input=c_quizzes, - ar_mask=self.make_ar_mask(c_quizzes, s, m), + ar_mask=self.make_quiz_mask(c_quizzes, s, m), seq_logproba=seq_logproba, ) -- 2.39.5