From 8a548630c88957264306db4354e880414b0fa8ef Mon Sep 17 00:00:00 2001
From: François Fleuret
Date: Mon, 24 Jun 2024 07:15:15 +0200
Subject: [PATCH] Update.

---
 mygpt.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mygpt.py b/mygpt.py
index 7117e76..a178491 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -292,7 +292,7 @@ class MyGPT(nn.Module):
             )  # Needed to initialize the model's cache
         for s in range(to_generate.min(), to_generate.max() + 1):
             output = self(BracketedSequence(input, s, 1)).x
-            logits = output[:, s]
+            logits = output[:, s] / temperature
             if forbidden_tokens is not None:
                 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
             if forced_biases is not None:
-- 
2.20.1
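
For context, the one-line change above divides the logits by a sampling temperature before the next token is drawn, so values below 1.0 sharpen the distribution and values above 1.0 flatten it; a temperature of 1.0 leaves the logits unchanged and reproduces the previous behaviour. Below is a minimal, self-contained sketch of the same technique, assuming PyTorch; the function name sample_with_temperature and the toy tensors are illustrative and do not appear in mygpt.py.

    import torch

    def sample_with_temperature(logits, temperature=1.0):
        # Dividing by the temperature rescales the logits: temperature < 1.0
        # sharpens the distribution, temperature > 1.0 flattens it toward uniform.
        scaled = logits / temperature
        dist = torch.distributions.categorical.Categorical(logits=scaled)
        return dist.sample()

    # Illustrative usage: a batch of 2 positions over a 5-token vocabulary.
    logits = torch.randn(2, 5)
    next_tokens = sample_with_temperature(logits, temperature=0.7)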