Update.
author: François Fleuret <francois@fleuret.org>
Tue, 20 Aug 2024 20:44:49 +0000 (22:44 +0200)
committer: François Fleuret <francois@fleuret.org>
Tue, 20 Aug 2024 20:44:49 +0000 (22:44 +0200)
main.py
mygpt.py
quiz_machine.py

diff --git a/main.py b/main.py
index 8908613..78d01ff 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -327,14 +327,6 @@ def optimizer_to(optim, device):
 ######################################################################
 
 
-def mask_ar_to_ranks(mask_ar):
-    a = (mask_ar < 2).long()
-    a = a.cumsum(dim=1) - a
-    b = ((mask_ar[:, :-1] == 2) & (mask_ar[:, 1:] != 2)).long().cumsum(dim=1)
-    a[:, 1:] += b
-    return a
-
-
 def run_tests(model, quiz_machine, local_device=main_device):
     with torch.autograd.no_grad():
         model.to(local_device).eval()
@@ -366,7 +358,7 @@ def run_tests(model, quiz_machine, local_device=main_device):
             targets = input
 
             output = model(
-                mygpt.BracketedSequence(input, ranks=mask_ar_to_ranks(mask_ar))
+                mygpt.BracketedSequence(input, ranks=mygpt.mask_ar_to_ranks(mask_ar))
             ).x
             loss_per_token = F.cross_entropy(
                 output.transpose(1, 2), targets, reduction="none"
@@ -427,7 +419,7 @@ def one_epoch(model, quiz_machine, local_device=main_device):
         targets = input
 
         output = model(
-            mygpt.BracketedSequence(input, ranks=mask_ar_to_ranks(mask_ar))
+            mygpt.BracketedSequence(input, ranks=mygpt.mask_ar_to_ranks(mask_ar))
         ).x
 
         loss_per_token = F.cross_entropy(
index c69c899..cd5b580 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -75,6 +75,14 @@ class RandomBypass(nn.Module):
 # resetted when the input bracket starts at t=0
 
 
+def mask_ar_to_ranks(mask_ar):
+    a = (mask_ar < 2).long()
+    a = a.cumsum(dim=1) - a
+    b = ((mask_ar[:, :-1] == 2) & (mask_ar[:, 1:] != 2)).long().cumsum(dim=1)
+    a[:, 1:] += b
+    return a
+
+
 class BracketedSequence:
     def __init__(self, x, first=None, nb=None, ranks=None):
         self.x = x
diff --git a/quiz_machine.py b/quiz_machine.py
index d209a07..8cec909 100755 (executable)
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -47,7 +47,9 @@ def one_batch_masked_inplace_autoregression(
     s = to_generate.min()
 
     for s, u in zip(indices_1[:-1], indices_1[1:]):
-        logits = model(BracketedSequence(input, s, u - s)).x
+        logits = model(
+            BracketedSequence(input, s, u - s, ranks=mygpt.mask_ar_to_ranks(mask_ar))
+        ).x
 
         if deterministic_synthesis:
             t_next = logits.argmax(dim=2)
@@ -228,7 +230,7 @@ class QuizMachine:
             i = self.problem.indices_select(quizzes=input, struct=struct)
             nb += i.long().sum()
             result[i], correct[i], _ = self.predict(
-                model=model, quizzes=input[i], struct=struct, quad=quad_ar
+                model=model, quizzes=input[i], struct=struct, quad_ar=quad_ar
             )
 
             predicted_parts[i] = torch.tensor(quad_ar, device=self.device)[None, :]