3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
9 # You need to install PyTorch
11 # https://pytorch.org/get-started/locally/
# and Huggingface's transformers (which includes pre-trained GPT-2 models)
16 # pip install transformers
import torch

from transformers import GPT2Tokenizer, GPT2LMHeadModel
23 ######################################################################
def complete(model, primer, nb_sentences = 1, nb_token_max = 100, temperature = None):
    """Extend ``primer`` with tokens generated by ``model`` and return the text.

    Decoding is greedy (arg-max over the last position's logits) when
    ``temperature`` is None; otherwise the next token is sampled from a
    categorical distribution over ``logits / temperature``. Generation stops
    as soon as ``nb_sentences`` tokens decoding to '.' have been produced, or
    ``nb_token_max`` tokens have been generated, whichever comes first.

    Encoding/decoding uses the module-level ``tokenizer``.

    Args:
        model: causal LM called on a (1, seq_len) tensor of token ids and
            returning an object with a ``.logits`` attribute
            (e.g. GPT2LMHeadModel).
        primer: prompt string to continue; must encode to at least one token.
        nb_sentences: number of '.' tokens after which generation stops.
        nb_token_max: upper bound on the number of generated tokens.
        temperature: None for greedy decoding, otherwise the sampling
            temperature (higher is more random).

    Returns:
        ``'<' + primer_text + '>' + continuation_text``, both parts decoded
        by ``tokenizer``.

    Raises:
        ValueError: if ``primer`` encodes to no tokens (the model needs at
            least one position to predict from).
    """
    tokens = tokenizer.encode(primer)
    primer_len = len(tokens)
    if primer_len == 0:
        raise ValueError('primer must encode to at least one token')

    nb_generated, nb_periods = 0, 0

    # Checking the limits up front (instead of `==` after each token) also
    # terminates for non-positive limits instead of looping forever.
    while nb_periods < nb_sentences and nb_generated < nb_token_max:
        # Pure inference: don't build an autograd graph for each forward pass.
        with torch.no_grad():
            # Logits predicting the token after the last position.
            logits = model(torch.tensor([tokens])).logits[0, -1]

        if temperature is None:
            next_token = torch.argmax(logits).item()
        else:
            dist = torch.distributions.Categorical(logits = logits / temperature)
            next_token = dist.sample().item()

        # Keep `tokens` a plain list of ints in both decoding modes.
        tokens.append(next_token)
        nb_generated += 1
        if tokenizer.decode([next_token]) == '.':
            nb_periods += 1

    return ('<' + tokenizer.decode(tokens[:primer_len]) + '>'
            + tokenizer.decode(tokens[primer_len:]))
44 ######################################################################
# Pre-trained GPT-2 checkpoint to load (alternative checkpoint: 'gpt2-large').
model_name = 'gpt2-xl'

# Module-level globals: `complete` reads `tokenizer` directly.
tokenizer, model = (
    GPT2Tokenizer.from_pretrained(model_name),
    GPT2LMHeadModel.from_pretrained(model_name),
)
56 'The object was blue all over, but also green all over, it was a',
60 ######################################################################