######################################################################
-def complete(model, primer, nb_sentences = 1, nb_token_max = 100, temperature = None):
+
+def complete(
+    model, tokenizer, primer, nb_sentences=1, nb_token_max=100, temperature=None
+):
    nt, ns = 0, 0
    tokens = tokenizer.encode(primer)
    primer_len = len(tokens)
        if temperature is None:
            next_token = torch.argmax(outputs[0, -1])
        else:
-            dist = torch.distributions.Categorical(logits = outputs[0, -1] / temperature)
+            dist = torch.distributions.Categorical(logits=outputs[0, -1] / temperature)
            next_token = dist.sample((1,)).item()
        tokens.append(next_token)
        nt += 1
-        if tokenizer.decode([next_token]) == '.': ns += 1
+        if tokenizer.decode([next_token]) == ".":
+            ns += 1
        if ns == nb_sentences or nt == nb_token_max:
-            return '<' + tokenizer.decode(tokens[:primer_len]) + '>' + \
-                tokenizer.decode(tokens[primer_len:])
+            return (
+                "<"
+                + tokenizer.decode(tokens[:primer_len])
+                + ">"
+                + tokenizer.decode(tokens[primer_len:])
+            )
+
######################################################################
-#model_name = 'gpt2'
-#model_name = 'gpt2-large'
-model_name = 'gpt2-xl'
+# model_name = 'gpt2'
+# model_name = 'gpt2-large'
+model_name = "gpt2-xl"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
model.eval()
print(
-    complete(model,
-             'The object was blue all over, but also green all over, it was a',
+    f"Using {model_name} ({int(sum(p.numel() for p in model.parameters())/(1e6))}M parameters)"
+)
+
+print(
+    complete(
+        model,
+        tokenizer,
+        "The object was blue all over, but also green all over, it was a",
    )
)
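
For readers who want to try the decoding step touched by this diff in isolation, here is a minimal, self-contained sketch of the greedy-vs-temperature choice made inside complete(). The random logits tensor and the 0.8 temperature are illustrative stand-ins, not values from the script.

# Sketch of the decoding step inside complete(): greedy argmax when
# temperature is None, otherwise sampling from Categorical(logits / T).
# The logits here are random, only to make the snippet runnable on its own.
import torch

logits = torch.randn(50257)  # stand-in for outputs[0, -1] over the GPT-2 vocabulary

greedy_token = torch.argmax(logits).item()  # temperature=None branch

temperature = 0.8  # illustrative value; <1 sharpens, >1 flattens the distribution
dist = torch.distributions.Categorical(logits=logits / temperature)
sampled_token = dist.sample((1,)).item()  # temperature branch

print(greedy_token, sampled_token)

Dividing the logits by the temperature before building the Categorical is equivalent to sampling from softmax(logits / T), so temperature=1 reproduces the model's own distribution and lower values move the samples toward the greedy choice.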