projects
/
pysvrt.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
e81aa33
)
Now generate samples by batches.
author
Francois Fleuret
<francois@fleuret.org>
Tue, 9 Jan 2018 16:20:25 +0000
(17:20 +0100)
committer
Francois Fleuret
<francois@fleuret.org>
Tue, 9 Jan 2018 16:20:25 +0000
(17:20 +0100)
generate.py
patch
|
blob
|
history
diff --git
a/generate.py
b/generate.py
index
7e40fb4
..
12a2cbc
100755
(executable)
--- a/
generate.py
+++ b/
generate.py
@@
-49,7
+49,12
@@
parser = argparse.ArgumentParser(
parser.add_argument('--nb_samples',
type = int,
default = 1000,
parser.add_argument('--nb_samples',
type = int,
default = 1000,
- help='How many samples to generate')
+ help='How many samples to generate in total')
+
+parser.add_argument('--batch_size',
+ type = int,
+ default = 1000,
+ help='How many samples to generate at once')
parser.add_argument('--problem',
type = int,
parser.add_argument('--problem',
type = int,
@@
-72,11
+77,9
@@
if os.path.isdir(args.data_dir):
else:
raise FileNotFoundError('Cannot find ' + args.data_dir)
else:
raise FileNotFoundError('Cannot find ' + args.data_dir)
-batch_size = 100
-
-for n in range(0, args.nb_samples, batch_size):
+for n in range(0, args.nb_samples, args.batch_size):
print(n, '/', args.nb_samples)
print(n, '/', args.nb_samples)
- labels = torch.LongTensor(min(batch_size, args.nb_samples - n)).zero_()
+ labels = torch.LongTensor(min(
args.
batch_size, args.nb_samples - n)).zero_()
labels.narrow(0, 0, labels.size(0)//2).fill_(1)
x = svrt.generate_vignettes(args.problem, labels).float()
x.sub_(128).div_(64)
labels.narrow(0, 0, labels.size(0)//2).fill_(1)
x = svrt.generate_vignettes(args.problem, labels).float()
x.sub_(128).div_(64)