parser.add_argument('--nb_samples',
type = int,
default = 1000,
- help='How many samples to generate')
+ help='How many samples to generate in total')
+
+parser.add_argument('--batch_size',
+ type = int,
+ default = 1000,
+ help='How many samples to generate at once')
parser.add_argument('--problem',
type = int,
else:
raise FileNotFoundError('Cannot find ' + args.data_dir)
-batch_size = 100
-
-for n in range(0, args.nb_samples, batch_size):
+for n in range(0, args.nb_samples, args.batch_size):
print(n, '/', args.nb_samples)
- labels = torch.LongTensor(min(batch_size, args.nb_samples - n)).zero_()
+ labels = torch.LongTensor(min(args.batch_size, args.nb_samples - n)).zero_()
labels.narrow(0, 0, labels.size(0)//2).fill_(1)
x = svrt.generate_vignettes(args.problem, labels).float()
x.sub_(128).div_(64)