- results = torch.zeros(nb_samples, nb_tokens_to_generate, dtype = torch.int64, device = device)
- for input in results.split(self.batch_size):
- for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
+ results = torch.zeros(
+ nb_samples, nb_tokens_to_generate,
+ dtype = torch.int64, device = device
+ )
+
+ if starting_input is None:
+ first = 0
+ else:
+ first = starting_input.size(1)
+ results = torch.cat((starting_input, results), 1)
+
+ for input in results.split(args.batch_size):
+ for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'):