+def autoregression(
+ model, batch_size,
+ nb_samples, nb_tokens_to_generate, primer = None,
+ device = torch.device('cpu')
+):
+ results = torch.zeros(
+ nb_samples, nb_tokens_to_generate,
+ dtype = torch.int64, device = device
+ )
+
+ if primer is None:
+ first = 0
+ else:
+ first = primer.size(1)
+ results = torch.cat((primer, results), 1)
+
+ for input in results.split(batch_size):
+ for s in range(first, input.size(1)):
+ output = model(input)
+ logits = output[:, s]
+ if args.synthesis_sampling:
+ dist = torch.distributions.categorical.Categorical(logits = logits)
+ t_next = dist.sample()
+ else:
+ t_next = logits.argmax(1)
+ input[:, s] = t_next
+
+ return results
+
+######################################################################
+