######################################################################
def autoregression(
- model,
+ model, batch_size,
nb_samples, nb_tokens_to_generate, starting_input = None,
device = torch.device('cpu')
):
first = starting_input.size(1)
results = torch.cat((starting_input, results), 1)
- for input in results.split(args.batch_size):
+ for input in results.split(batch_size):
for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'):
output = model(input)
logits = output[:, s]
return 256
def produce_results(self, n_epoch, model, nb_samples = 64):
- results = autoregression(model, nb_samples, 28 * 28, device = self.device)
+ results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
image_name = f'result_mnist_{n_epoch:04d}.png'
torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
image_name, nrow = 16, pad_value = 0.8)