From: Francois Fleuret Date: Tue, 9 Jan 2018 16:20:25 +0000 (+0100) Subject: Now generate samples by batches. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=e5da1f7362063a680687b24b2460ce775c374833;p=pysvrt.git Now generate samples by batches. --- diff --git a/generate.py b/generate.py index 7e40fb4..12a2cbc 100755 --- a/generate.py +++ b/generate.py @@ -49,7 +49,12 @@ parser = argparse.ArgumentParser( parser.add_argument('--nb_samples', type = int, default = 1000, - help='How many samples to generate') + help='How many samples to generate in total') + +parser.add_argument('--batch_size', + type = int, + default = 1000, + help='How many samples to generate at once') parser.add_argument('--problem', type = int, @@ -72,11 +77,9 @@ if os.path.isdir(args.data_dir): else: raise FileNotFoundError('Cannot find ' + args.data_dir) -batch_size = 100 - -for n in range(0, args.nb_samples, batch_size): +for n in range(0, args.nb_samples, args.batch_size): print(n, '/', args.nb_samples) - labels = torch.LongTensor(min(batch_size, args.nb_samples - n)).zero_() + labels = torch.LongTensor(min(args.batch_size, args.nb_samples - n)).zero_() labels.narrow(0, 0, labels.size(0)//2).fill_(1) x = svrt.generate_vignettes(args.problem, labels).float() x.sub_(128).div_(64)