X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;fp=main.py;h=b2adf986cb57c9bd447571be06d62379ad97d756;hb=8ea0e3c5cc303718a8b508b656f7aa9e64ea3070;hp=65922040a78fc60ada7bbf75c5373ace576ee4c6;hpb=c92c5abfa01b78a292929d2363bb05798c2af39f;p=mygpt.git diff --git a/main.py b/main.py index 6592204..b2adf98 100755 --- a/main.py +++ b/main.py @@ -111,7 +111,7 @@ for n in vars(args): ###################################################################### def autoregression( - model, + model, batch_size, nb_samples, nb_tokens_to_generate, starting_input = None, device = torch.device('cpu') ): @@ -126,7 +126,7 @@ def autoregression( first = starting_input.size(1) results = torch.cat((starting_input, results), 1) - for input in results.split(args.batch_size): + for input in results.split(batch_size): for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'): output = model(input) logits = output[:, s] @@ -386,7 +386,7 @@ class TaskMNIST(Task): return 256 def produce_results(self, n_epoch, model, nb_samples = 64): - results = autoregression(model, nb_samples, 28 * 28, device = self.device) + results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device) image_name = f'result_mnist_{n_epoch:04d}.png' torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255., image_name, nrow = 16, pad_value = 0.8)