Now generate samples by batches.
authorFrancois Fleuret <francois@fleuret.org>
Tue, 9 Jan 2018 16:20:25 +0000 (17:20 +0100)
committerFrancois Fleuret <francois@fleuret.org>
Tue, 9 Jan 2018 16:20:25 +0000 (17:20 +0100)
generate.py

index 7e40fb4..12a2cbc 100755 (executable)
@@ -49,7 +49,12 @@ parser = argparse.ArgumentParser(
 parser.add_argument('--nb_samples',
                     type = int,
                     default = 1000,
-                    help='How many samples to generate')
+                    help='How many samples to generate in total')
+
+parser.add_argument('--batch_size',
+                    type = int,
+                    default = 1000,
+                    help='How many samples to generate at once')
 
 parser.add_argument('--problem',
                     type = int,
@@ -72,11 +77,9 @@ if os.path.isdir(args.data_dir):
 else:
     raise FileNotFoundError('Cannot find ' + args.data_dir)
 
-batch_size = 100
-
-for n in range(0, args.nb_samples, batch_size):
+for n in range(0, args.nb_samples, args.batch_size):
     print(n, '/', args.nb_samples)
-    labels = torch.LongTensor(min(batch_size, args.nb_samples - n)).zero_()
+    labels = torch.LongTensor(min(args.batch_size, args.nb_samples - n)).zero_()
     labels.narrow(0, 0, labels.size(0)//2).fill_(1)
     x = svrt.generate_vignettes(args.problem, labels).float()
     x.sub_(128).div_(64)