X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pysvrt.git;a=blobdiff_plain;f=generate.py;h=12a2cbc37087d71575423c156e8c2a91953f27e0;hp=7f29683c7a843699d32a74c00da0717f447de88c;hb=HEAD;hpb=756b8e59a361755f88917aaf4bc214c5f6cce94b diff --git a/generate.py b/generate.py index 7f29683..12a2cbc 100755 --- a/generate.py +++ b/generate.py @@ -49,7 +49,12 @@ parser = argparse.ArgumentParser( parser.add_argument('--nb_samples', type = int, default = 1000, - help='How many samples to generate') + help='How many samples to generate in total') + +parser.add_argument('--batch_size', + type = int, + default = 1000, + help='How many samples to generate at once') parser.add_argument('--problem', type = int, @@ -72,14 +77,12 @@ if os.path.isdir(args.data_dir): else: raise FileNotFoundError('Cannot find ' + args.data_dir) -labels = torch.LongTensor(args.nb_samples).zero_() -labels.narrow(0, 0, labels.size(0)//2).fill_(1) -x = svrt.generate_vignettes(args.problem, labels).float() - -x.sub_(128).div_(64) - -print('MEAN', x.mean(), 'STD', x.std()) - -for k in range(x.size(0)): - filename = args.data_dir + '/problem_{:02d}/class_{:d}/img_{:06d}.png'.format(args.problem, labels[k], k) - torchvision.utils.save_image(x[k].view(1, x.size(1), x.size(2)), filename) +for n in range(0, args.nb_samples, args.batch_size): + print(n, '/', args.nb_samples) + labels = torch.LongTensor(min(args.batch_size, args.nb_samples - n)).zero_() + labels.narrow(0, 0, labels.size(0)//2).fill_(1) + x = svrt.generate_vignettes(args.problem, labels).float() + x.sub_(128).div_(64) + for k in range(x.size(0)): + filename = args.data_dir + '/problem_{:02d}/class_{:d}/img_{:07d}.png'.format(args.problem, labels[k], k + n) + torchvision.utils.save_image(x[k].view(1, x.size(1), x.size(2)), filename)