Minor update.
[pysvrt.git] / generate.py
index 7f29683..12a2cbc 100755 (executable)
@@ -49,7 +49,12 @@ parser = argparse.ArgumentParser(
 parser.add_argument('--nb_samples',
                     type = int,
                     default = 1000,
-                    help='How many samples to generate')
+                    help='How many samples to generate in total')
+
+parser.add_argument('--batch_size',
+                    type = int,
+                    default = 1000,
+                    help='How many samples to generate at once')
 
 parser.add_argument('--problem',
                     type = int,
@@ -72,14 +77,12 @@ if os.path.isdir(args.data_dir):
 else:
     raise FileNotFoundError('Cannot find ' + args.data_dir)
 
-labels = torch.LongTensor(args.nb_samples).zero_()
-labels.narrow(0, 0, labels.size(0)//2).fill_(1)
-x = svrt.generate_vignettes(args.problem, labels).float()
-
-x.sub_(128).div_(64)
-
-print('MEAN', x.mean(), 'STD', x.std())
-
-for k in range(x.size(0)):
-    filename = args.data_dir + '/problem_{:02d}/class_{:d}/img_{:06d}.png'.format(args.problem, labels[k], k)
-    torchvision.utils.save_image(x[k].view(1, x.size(1), x.size(2)), filename)
+for n in range(0, args.nb_samples, args.batch_size):
+    print(n, '/', args.nb_samples)
+    labels = torch.LongTensor(min(args.batch_size, args.nb_samples - n)).zero_()
+    labels.narrow(0, 0, labels.size(0)//2).fill_(1)
+    x = svrt.generate_vignettes(args.problem, labels).float()
+    x.sub_(128).div_(64)
+    for k in range(x.size(0)):
+        filename = args.data_dir + '/problem_{:02d}/class_{:d}/img_{:07d}.png'.format(args.problem, labels[k], k + n)
+        torchvision.utils.save_image(x[k].view(1, x.size(1), x.size(2)), filename)