Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 30 Jul 2024 08:04:27 +0000 (10:04 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 30 Jul 2024 08:04:27 +0000 (10:04 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index 9faa7bd..6f543a0 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -365,14 +365,9 @@ def run_tests(model, quiz_machine, local_device=main_device):
 
         for input in tqdm.tqdm(src, dynamic_ncols=True, desc="test"):
             input = input.to(local_device)
-
-            bs = model(mygpt.BracketedSequence(input))
-            output = bs.x
-
+            output = model(mygpt.BracketedSequence(input)).x
             loss = F.cross_entropy(output.transpose(1, 2), input)
-
             acc_test_loss += loss.item() * input.size(0)
-
             nb_test_samples += input.size(0)
 
         test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
@@ -387,6 +382,9 @@ def run_tests(model, quiz_machine, local_device=main_device):
         )
 
 
+######################################################################
+
+
 def one_epoch(model, quiz_machine, local_device=main_device):
     model.to(local_device).train()
 
@@ -467,6 +465,8 @@ c_quizzes_procedure = [
     (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
 ]
 
+######################################################################
+
 
 def save_additional_results(models, science_w_quizzes):
     for model in models:
index 134bf21..9ca84b3 100755 (executable)
@@ -184,6 +184,8 @@ class QuizMachine:
         assert struct in self.train_struct
         return self.problem.make_ar_mask(quizzes, struct=struct, mask=mask)
 
+    ######################################################################
+
     def predict(self, model, quizzes, struct, mask):
         ar_mask = self.make_ar_mask(quizzes=quizzes, struct=struct, mask=mask)
         result = quizzes * (1 - ar_mask)
@@ -202,13 +204,9 @@ class QuizMachine:
 
         return result, correct
 
-    def produce_results(
-        self,
-        n_epoch,
-        model,
-        input,
-        result_dir,
-    ):
+    ######################################################################
+
+    def produce_results(self, n_epoch, model, input, result_dir):
         input = input.to(self.device)
         result = input.new(input.size())
         correct = input.new(input.size(0))