X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pysvrt.git;a=blobdiff_plain;f=generate.py;h=12a2cbc37087d71575423c156e8c2a91953f27e0;hp=7e40fb4ba4450bdd482930d0dd813cfecc93ccc5;hb=HEAD;hpb=e81aa33e97c95d97d7c24fbcfae220e97a9b2887 diff --git a/generate.py b/generate.py index 7e40fb4..12a2cbc 100755 --- a/generate.py +++ b/generate.py @@ -49,7 +49,12 @@ parser = argparse.ArgumentParser( parser.add_argument('--nb_samples', type = int, default = 1000, - help='How many samples to generate') + help='How many samples to generate in total') + +parser.add_argument('--batch_size', + type = int, + default = 1000, + help='How many samples to generate at once') parser.add_argument('--problem', type = int, @@ -72,11 +77,9 @@ if os.path.isdir(args.data_dir): else: raise FileNotFoundError('Cannot find ' + args.data_dir) -batch_size = 100 - -for n in range(0, args.nb_samples, batch_size): +for n in range(0, args.nb_samples, args.batch_size): print(n, '/', args.nb_samples) - labels = torch.LongTensor(min(batch_size, args.nb_samples - n)).zero_() + labels = torch.LongTensor(min(args.batch_size, args.nb_samples - n)).zero_() labels.narrow(0, 0, labels.size(0)//2).fill_(1) x = svrt.generate_vignettes(args.problem, labels).float() x.sub_(128).div_(64)