--- /dev/null
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+#
+# Written by Francois Fleuret <francois@fleuret.org>
+
+#
+# You need to install PyTorch
+#
+# https://pytorch.org/get-started/locally/
+#
# and Huggingface's transformers (which includes pre-trained GPT
# models)
+#
+# pip install transformers
+#
+
+import torch
+
+from transformers import GPT2Tokenizer, GPT2LMHeadModel
+
+######################################################################
+
def complete(model, primer, nb_sentences = 1, nb_token_max = 100, temperature = None):
    """Generate a text completion of `primer` with `model`.

    Tokens are generated one at a time until either `nb_sentences`
    sentence-ending periods have been produced or `nb_token_max` tokens
    have been generated, whichever comes first. If `temperature` is
    None, decoding is greedy (argmax of the logits); otherwise the next
    token is sampled from the categorical distribution defined by
    logits / temperature.

    Returns the decoded text, with the primer part wrapped in '<...>'.

    NOTE: relies on the module-level `tokenizer` for encoding/decoding.
    """
    nt, ns = 0, 0
    tokens = tokenizer.encode(primer)
    primer_len = len(tokens)
    while True:
        # Inference only: disabling autograd avoids building a gradient
        # graph at every step (which would waste time and memory).
        with torch.no_grad():
            outputs = model(torch.tensor([tokens])).logits
        if temperature is None:
            # Greedy decoding. .item() converts the 0-d tensor to a
            # plain int so `tokens` stays a homogeneous list of ints.
            next_token = torch.argmax(outputs[0, -1]).item()
        else:
            dist = torch.distributions.Categorical(logits = outputs[0, -1] / temperature)
            next_token = dist.sample().item()

        tokens.append(next_token)
        nt += 1
        if tokenizer.decode([next_token]) == '.': ns += 1
        if ns == nb_sentences or nt == nb_token_max:
            return '<' + tokenizer.decode(tokens[:primer_len]) + '>' + \
                tokenizer.decode(tokens[primer_len:])
+
+######################################################################
+
#model_name = 'gpt2'
#model_name = 'gpt2-large'
model_name = 'gpt2-xl'

# Fetch (or load from the local cache) the pre-trained tokenizer and
# model weights, then switch the model to inference mode.
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
model.eval()

primer = 'The object was blue all over, but also green all over, it was a'

print(complete(model, primer))
+
+######################################################################