From e5da1f7362063a680687b24b2460ce775c374833 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Tue, 9 Jan 2018 17:20:25 +0100 Subject: [PATCH] Now generate samples by batches. --- generate.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/generate.py b/generate.py index 7e40fb4..12a2cbc 100755 --- a/generate.py +++ b/generate.py @@ -49,7 +49,12 @@ parser = argparse.ArgumentParser( parser.add_argument('--nb_samples', type = int, default = 1000, - help='How many samples to generate') + help='How many samples to generate in total') + +parser.add_argument('--batch_size', + type = int, + default = 1000, + help='How many samples to generate at once') parser.add_argument('--problem', type = int, @@ -72,11 +77,9 @@ if os.path.isdir(args.data_dir): else: raise FileNotFoundError('Cannot find ' + args.data_dir) -batch_size = 100 - -for n in range(0, args.nb_samples, batch_size): +for n in range(0, args.nb_samples, args.batch_size): print(n, '/', args.nb_samples) - labels = torch.LongTensor(min(batch_size, args.nb_samples - n)).zero_() + labels = torch.LongTensor(min(args.batch_size, args.nb_samples - n)).zero_() labels.narrow(0, 0, labels.size(0)//2).fill_(1) x = svrt.generate_vignettes(args.problem, labels).float() x.sub_(128).div_(64) -- 2.39.5