From a2bf298d8b609810bbb4caaabf6633deee768481 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Tue, 23 Mar 2021 17:22:49 +0100 Subject: [PATCH] Initial commit. --- gpt-test.py | 60 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100755 gpt-test.py diff --git a/gpt-test.py b/gpt-test.py new file mode 100755 index 0000000..ff72e50 --- /dev/null +++ b/gpt-test.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ +# +# Written by Francois Fleuret + +# +# You need to install PyTorch +# +# https://pytorch.org/get-started/locally/ +# +# and Huggingface's transformers (which include pre-trained GPT +# models) +# +# pip install transformers +# + +import torch + +from transformers import GPT2Tokenizer, GPT2LMHeadModel + +###################################################################### + +def complete(model, primer, nb_sentences = 1, nb_token_max = 100, temperature = None): + nt, ns = 0, 0 + tokens = tokenizer.encode(primer) + primer_len = len(tokens) + while True: + outputs = model(torch.tensor([tokens])).logits + if temperature is None: + next_token = torch.argmax(outputs[0, -1]) + else: + dist = torch.distributions.Categorical(logits = outputs[0, -1] / temperature) + next_token = dist.sample((1,)).item() + + tokens.append(next_token) + nt += 1 + if tokenizer.decode([next_token]) == '.': ns += 1 + if ns == nb_sentences or nt == nb_token_max: + return '<' + tokenizer.decode(tokens[:primer_len]) + '>' + \ + tokenizer.decode(tokens[primer_len:]) + +###################################################################### + +#model_name = 'gpt2' +#model_name = 'gpt2-large' +model_name = 'gpt2-xl' + +tokenizer = GPT2Tokenizer.from_pretrained(model_name) +model = GPT2LMHeadModel.from_pretrained(model_name) +model.eval() + +print( + complete(model, + 'The object was blue all over, but also green all over, it was a', + ) +) + +###################################################################### -- 2.39.5