Update.
authorFrancois Fleuret <francois@fleuret.org>
Tue, 9 Jan 2018 16:19:32 +0000 (17:19 +0100)
committerFrancois Fleuret <francois@fleuret.org>
Tue, 9 Jan 2018 16:19:32 +0000 (17:19 +0100)
generate.py

index 7f29683..7e40fb4 100755 (executable)
@@ -72,14 +72,14 @@ if os.path.isdir(args.data_dir):
 else:
     raise FileNotFoundError('Cannot find ' + args.data_dir)
 
-labels = torch.LongTensor(args.nb_samples).zero_()
-labels.narrow(0, 0, labels.size(0)//2).fill_(1)
-x = svrt.generate_vignettes(args.problem, labels).float()
-
-x.sub_(128).div_(64)
-
-print('MEAN', x.mean(), 'STD', x.std())
-
-for k in range(x.size(0)):
-    filename = args.data_dir + '/problem_{:02d}/class_{:d}/img_{:06d}.png'.format(args.problem, labels[k], k)
-    torchvision.utils.save_image(x[k].view(1, x.size(1), x.size(2)), filename)
+batch_size = 100
+
+for n in range(0, args.nb_samples, batch_size):
+    print(n, '/', args.nb_samples)
+    labels = torch.LongTensor(min(batch_size, args.nb_samples - n)).zero_()
+    labels.narrow(0, 0, labels.size(0)//2).fill_(1)
+    x = svrt.generate_vignettes(args.problem, labels).float()
+    x.sub_(128).div_(64)
+    for k in range(x.size(0)):
+        filename = args.data_dir + '/problem_{:02d}/class_{:d}/img_{:07d}.png'.format(args.problem, labels[k], k + n)
+        torchvision.utils.save_image(x[k].view(1, x.size(1), x.size(2)), filename)