From: Francois Fleuret Date: Tue, 9 Jan 2018 16:19:32 +0000 (+0100) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=e81aa33e97c95d97d7c24fbcfae220e97a9b2887;p=pysvrt.git Update. --- diff --git a/generate.py b/generate.py index 7f29683..7e40fb4 100755 --- a/generate.py +++ b/generate.py @@ -72,14 +72,14 @@ if os.path.isdir(args.data_dir): else: raise FileNotFoundError('Cannot find ' + args.data_dir) -labels = torch.LongTensor(args.nb_samples).zero_() -labels.narrow(0, 0, labels.size(0)//2).fill_(1) -x = svrt.generate_vignettes(args.problem, labels).float() - -x.sub_(128).div_(64) - -print('MEAN', x.mean(), 'STD', x.std()) - -for k in range(x.size(0)): - filename = args.data_dir + '/problem_{:02d}/class_{:d}/img_{:06d}.png'.format(args.problem, labels[k], k) - torchvision.utils.save_image(x[k].view(1, x.size(1), x.size(2)), filename) +batch_size = 100 + +for n in range(0, args.nb_samples, batch_size): + print(n, '/', args.nb_samples) + labels = torch.LongTensor(min(batch_size, args.nb_samples - n)).zero_() + labels.narrow(0, 0, labels.size(0)//2).fill_(1) + x = svrt.generate_vignettes(args.problem, labels).float() + x.sub_(128).div_(64) + for k in range(x.size(0)): + filename = args.data_dir + '/problem_{:02d}/class_{:d}/img_{:07d}.png'.format(args.problem, labels[k], k + n) + torchvision.utils.save_image(x[k].view(1, x.size(1), x.size(2)), filename)