Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 24 Jun 2024 10:13:01 +0000 (12:13 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 24 Jun 2024 10:13:01 +0000 (12:13 +0200)
mygpt.py
tasks.py

index a178491..c58bea1 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -292,11 +292,16 @@ class MyGPT(nn.Module):
             )  # Needed to initialize the model's cache
         for s in range(to_generate.min(), to_generate.max() + 1):
             output = self(BracketedSequence(input, s, 1)).x
-            logits = output[:, s] / temperature
+            logits = output[:, s]
+
+            logits = logits.log_softmax(dim=-1) / temperature
+
             if forbidden_tokens is not None:
                 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
+
             if forced_biases is not None:
                 logits = logits + forced_biases[None, :]
+
             if deterministic_synthesis:
                 t_next = logits.argmax(1)
             else:
index b967465..5edb472 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -274,6 +274,10 @@ class World(Task):
             device=self.device,
         )
 
+        # Should not be necessary though, the autoregression is done
+        # in eval mode
+        sum_logits = sum_logits.detach()
+
         average_logits = sum_logits / quizzes.numel()
 
         # It's a bit brutal to do it twice, we should probably have a