X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=gpt-test.py;h=096704334c2e89e4af94c28afefccf757fc91f3d;hp=ff72e5076c7695584420bdd0a4cbac952daa8a9f;hb=HEAD;hpb=a2bf298d8b609810bbb4caaabf6633deee768481

diff --git a/gpt-test.py b/gpt-test.py
index ff72e50..ddd7dcf 100755
--- a/gpt-test.py
+++ b/gpt-test.py
@@ -22,7 +22,10 @@ from transformers import GPT2Tokenizer, GPT2LMHeadModel
 
 ######################################################################
 
-def complete(model, primer, nb_sentences = 1, nb_token_max = 100, temperature = None):
+
+def complete(
+    model, tokenizer, primer, nb_sentences=1, nb_token_max=100, temperature=None
+):
     nt, ns = 0, 0
     tokens = tokenizer.encode(primer)
     primer_len = len(tokens)
@@ -31,29 +34,41 @@ def complete(model, primer, nb_sentences = 1, nb_token_max = 100, temperature =
         if temperature is None:
             next_token = torch.argmax(outputs[0, -1])
         else:
-            dist = torch.distributions.Categorical(logits = outputs[0, -1] / temperature)
+            dist = torch.distributions.Categorical(logits=outputs[0, -1] / temperature)
             next_token = dist.sample((1,)).item()
         tokens.append(next_token)
         nt += 1
-        if tokenizer.decode([next_token]) == '.': ns += 1
+        if tokenizer.decode([next_token]) == ".":
+            ns += 1
         if ns == nb_sentences or nt == nb_token_max:
-            return '<' + tokenizer.decode(tokens[:primer_len]) + '>' + \
-                tokenizer.decode(tokens[primer_len:])
+            return (
+                "<"
+                + tokenizer.decode(tokens[:primer_len])
+                + ">"
+                + tokenizer.decode(tokens[primer_len:])
+            )
+
 
 ######################################################################
 
-#model_name = 'gpt2'
-#model_name = 'gpt2-large'
-model_name = 'gpt2-xl'
+# model_name = 'gpt2'
+# model_name = 'gpt2-large'
+model_name = "gpt2-xl"
 
 tokenizer = GPT2Tokenizer.from_pretrained(model_name)
 model = GPT2LMHeadModel.from_pretrained(model_name)
 model.eval()
 
 print(
-    complete(model,
-             'The object was blue all over, but also green all over, it was a',
+    f"Using {model_name} ({int(sum(p.numel() for p in model.parameters())/(1e6))}M parameters)"
+)
+
+print(
+    complete(
+        model,
+        tokenizer,
+        "The object was blue all over, but also green all over, it was a",
     )
 )
 
 ######################################################################
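
For context, a minimal usage sketch of the updated complete() signature (model, tokenizer, primer, ...) introduced by this change. It assumes the small "gpt2" checkpoint and a temperature of 1.0 purely for illustration (the script itself selects "gpt2-xl"), and it relies on the complete() helper defined in gpt-test.py being in scope.

from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Hypothetical choice: the small checkpoint keeps the example light to download.
model_name = "gpt2"

tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
model.eval()

primer = "The object was blue all over, but also green all over, it was a"

# Greedy decoding: with temperature=None, complete() takes the argmax token at each step.
print(complete(model, tokenizer, primer))

# Stochastic decoding: complete() divides the logits by the temperature before
# sampling from a Categorical distribution (temperature=1.0 is an arbitrary example value).
print(complete(model, tokenizer, primer, temperature=1.0))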