Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 11 Jul 2024 15:52:40 +0000 (17:52 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 11 Jul 2024 15:52:40 +0000 (17:52 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index 4cf4d59..a7338c7 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -18,6 +18,8 @@ import sky, grids, quiz_machine
 
 import threading
 
+import torch.multiprocessing as mp
+
 # world quizzes vs. culture quizzes
 
 ######################################################################
diff --git a/quiz_machine.py b/quiz_machine.py
index ae14614..8ab5696 100755 (executable)
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -424,17 +424,23 @@ class QuizMachine:
         )
 
         for model in models:
-            for input, l in zip(
-                c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
-            ):
-                input = input.to(self.device)
-                ar_mask = self.make_ar_mask(input)
-                output = model(mygpt.BracketedSequence(input)).x
-                ce = (
-                    F.cross_entropy(output.transpose(1, 2), input, reduction="none")
-                    * ar_mask
-                )
-                l[:, model.id] = -ce.sum(dim=-1)
+            with torch.autograd.no_grad():
+                t = model.training
+                model.eval()
+
+                for input, l in zip(
+                    c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
+                ):
+                    input = input.to(self.device)
+                    ar_mask = self.make_ar_mask(input)
+                    output = model(mygpt.BracketedSequence(input)).x
+                    ce = (
+                        F.cross_entropy(output.transpose(1, 2), input, reduction="none")
+                        * ar_mask
+                    )
+                    l[:, model.id] = -ce.sum(dim=-1)
+
+                model.train(t)
 
         return logproba.to("cpu")