Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 24 Jun 2024 05:15:15 +0000 (07:15 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 24 Jun 2024 05:15:15 +0000 (07:15 +0200)
mygpt.py

index 7117e76..a178491 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -292,7 +292,7 @@ class MyGPT(nn.Module):
             )  # Needed to initialize the model's cache
         for s in range(to_generate.min(), to_generate.max() + 1):
             output = self(BracketedSequence(input, s, 1)).x
             )  # Needed to initialize the model's cache
         for s in range(to_generate.min(), to_generate.max() + 1):
             output = self(BracketedSequence(input, s, 1)).x
-            logits = output[:, s]
+            logits = output[:, s] / temperature
             if forbidden_tokens is not None:
                 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
             if forced_biases is not None:
             if forbidden_tokens is not None:
                 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
             if forced_biases is not None: