else:
raise FileNotFoundError('Cannot find ' + args.data_dir)
-labels = torch.LongTensor(args.nb_samples).zero_()
-labels.narrow(0, 0, labels.size(0)//2).fill_(1)
-x = svrt.generate_vignettes(args.problem, labels).float()
-
-x.sub_(128).div_(64)
-
-print('MEAN', x.mean(), 'STD', x.std())
-
-for k in range(x.size(0)):
- filename = args.data_dir + '/problem_{:02d}/class_{:d}/img_{:06d}.png'.format(args.problem, labels[k], k)
- torchvision.utils.save_image(x[k].view(1, x.size(1), x.size(2)), filename)
+batch_size = 100
+
+for n in range(0, args.nb_samples, batch_size):
+ print(n, '/', args.nb_samples)
+ labels = torch.LongTensor(min(batch_size, args.nb_samples - n)).zero_()
+ labels.narrow(0, 0, labels.size(0)//2).fill_(1)
+ x = svrt.generate_vignettes(args.problem, labels).float()
+ x.sub_(128).div_(64)
+ for k in range(x.size(0)):
+ filename = args.data_dir + '/problem_{:02d}/class_{:d}/img_{:07d}.png'.format(args.problem, labels[k], k + n)
+ torchvision.utils.save_image(x[k].view(1, x.size(1), x.size(2)), filename)