- def generate(self, primer, model, nb_tokens):
- t_primer = primer.strip().split(' ')
- t_generated = [ ]
-
- for j in range(nb_tokens):
- t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
- input = torch.tensor(t, device = self.device)
- input = F.pad(input, (0, 1)) # Add the next token, the one to predict
- output = model(input)
- logits = output[0, -1]
- if args.synthesis_sampling:
- dist = torch.distributions.categorical.Categorical(logits = logits)
- t_next = dist.sample()
- else:
- t_next = logits.argmax()
- t_generated.append(self.id2token[t_next.item()])
-
- return ' '.join(t_primer + t_generated)
+ def generate(self, primer_descr, model, nb_tokens):
+ results = autoregression(
+ model, self.batch_size,
+ nb_samples = 1, nb_tokens = nb_tokens, primer = descr2tensor(primer_descr),
+ device = self.device
+ )
+ return ' '.join([ self.id2token[t.item()] for t in results.flatten() ])