From 3d87b29b369400cb64811a4cdd152f62f50a0931 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 30 Jul 2024 10:04:27 +0200 Subject: [PATCH] Update. --- main.py | 12 ++++++------ quiz_machine.py | 12 +++++------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/main.py b/main.py index 9faa7bd..6f543a0 100755 --- a/main.py +++ b/main.py @@ -365,14 +365,9 @@ def run_tests(model, quiz_machine, local_device=main_device): for input in tqdm.tqdm(src, dynamic_ncols=True, desc="test"): input = input.to(local_device) - - bs = model(mygpt.BracketedSequence(input)) - output = bs.x - + output = model(mygpt.BracketedSequence(input)).x loss = F.cross_entropy(output.transpose(1, 2), input) - acc_test_loss += loss.item() * input.size(0) - nb_test_samples += input.size(0) test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples)) @@ -387,6 +382,9 @@ def run_tests(model, quiz_machine, local_device=main_device): ) +###################################################################### + + def one_epoch(model, quiz_machine, local_device=main_device): model.to(local_device).train() @@ -467,6 +465,8 @@ c_quizzes_procedure = [ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold), ] +###################################################################### + def save_additional_results(models, science_w_quizzes): for model in models: diff --git a/quiz_machine.py b/quiz_machine.py index 134bf21..9ca84b3 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -184,6 +184,8 @@ class QuizMachine: assert struct in self.train_struct return self.problem.make_ar_mask(quizzes, struct=struct, mask=mask) + ###################################################################### + def predict(self, model, quizzes, struct, mask): ar_mask = self.make_ar_mask(quizzes=quizzes, struct=struct, mask=mask) result = quizzes * (1 - ar_mask) @@ -202,13 +204,9 @@ class QuizMachine: return result, correct - def produce_results( - self, - n_epoch, - model, - input, - result_dir, - ): + ###################################################################### + + def produce_results(self, n_epoch, model, input, result_dir): input = input.to(self.device) result = input.new(input.size()) correct = input.new(input.size(0)) -- 2.20.1