Initial commit.
author Francois Fleuret <francois@fleuret.org>
Tue, 23 Mar 2021 16:22:49 +0000 (17:22 +0100)
committer Francois Fleuret <francois@fleuret.org>
Tue, 23 Mar 2021 16:22:49 +0000 (17:22 +0100)
gpt-test.py [new file with mode: 0755]

diff --git a/gpt-test.py b/gpt-test.py
new file mode 100755 (executable)
index 0000000..ff72e50
--- /dev/null
@@ -0,0 +1,87 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+#
+# Written by Francois Fleuret <francois@fleuret.org>
+
+#
+# You need to install PyTorch
+#
+#   https://pytorch.org/get-started/locally/
+#
+# and Hugging Face's transformers library (which includes pre-trained
+# GPT models)
+#
+#  pip install transformers
+#
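+# (on most platforms a plain "pip install torch" also installs
+# PyTorch, but check the page above for CUDA-specific builds)
+#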
+
+import torch
+
+from transformers import GPT2Tokenizer, GPT2LMHeadModel
+
+######################################################################
+
+# Generates a continuation of the primer until nb_sentences sentences
+# (counted as '.' tokens) or nb_token_max tokens have been produced.
+# Uses the global tokenizer defined below.
+
+def complete(model, primer, nb_sentences = 1, nb_token_max = 100, temperature = None):
+    nt, ns = 0, 0
+    tokens = tokenizer.encode(primer)
+    primer_len = len(tokens)
+    while True:
+        # Re-run the model on the whole sequence to get the logits of
+        # the next token
+        with torch.no_grad():
+            outputs = model(torch.tensor([tokens])).logits
+        if temperature is None:
+            # Greedy decoding: take the most likely token
+            next_token = torch.argmax(outputs[0, -1]).item()
+        else:
+            # Sample from the temperature-scaled distribution
+            dist = torch.distributions.Categorical(logits = outputs[0, -1] / temperature)
+            next_token = dist.sample().item()
+
+        tokens.append(next_token)
+        nt += 1
+        if tokenizer.decode([next_token]) == '.': ns += 1
+        if ns == nb_sentences or nt == nb_token_max:
+            # Bracket the primer with <...> to separate it from the
+            # generated continuation
+            return '<' + tokenizer.decode(tokens[:primer_len]) + '>' + \
+                tokenizer.decode(tokens[primer_len:])
+
+######################################################################
+
+#model_name = 'gpt2'
+#model_name = 'gpt2-large'
+model_name = 'gpt2-xl'
+
+tokenizer = GPT2Tokenizer.from_pretrained(model_name)
+model = GPT2LMHeadModel.from_pretrained(model_name)
+# Inference only: switch off dropout
+model.eval()
+
+print(
+    complete(model,
+             'The object was blue all over, but also green all over, it was a',
+    )
+)
+
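+# A usage sketch of stochastic decoding with the same primer. The
+# temperature value of 1.0 here is an illustrative choice, not part of
+# the original greedy call above; lower values get closer to greedy
+# decoding.
+
+print(
+    complete(model,
+             'The object was blue all over, but also green all over, it was a',
+             temperature = 1.0,
+    )
+)
+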
+######################################################################