projects
/
pysvrt.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Bug fix. Ouch, that was bad.
[pysvrt.git]
/
svrtset.py
diff --git
a/svrtset.py
b/svrtset.py
index
cbc71a3
..
8bfdde1
100755
(executable)
--- a/
svrtset.py
+++ b/
svrtset.py
@@
-41,17
+41,18
@@
def generate_one_batch(s):
class VignetteSet:
class VignetteSet:
- def __init__(self, problem_number, nb_samples, batch_size, cuda = False):
+ def __init__(self, problem_number, nb_samples, batch_size, cuda = False
, logger = None
):
if nb_samples%batch_size > 0:
print('nb_samples must be a multiple of batch_size')
raise
self.cuda = cuda
if nb_samples%batch_size > 0:
print('nb_samples must be a multiple of batch_size')
raise
self.cuda = cuda
- self.batch_size = batch_size
self.problem_number = problem_number
self.problem_number = problem_number
- self.nb_batches = nb_samples // batch_size
- self.nb_samples = self.nb_batches * self.batch_size
+
+ self.batch_size = batch_size
+ self.nb_samples = nb_samples
+ self.nb_batches = self.nb_samples // self.batch_size
seeds = torch.LongTensor(self.nb_batches).random_()
mp_args = []
seeds = torch.LongTensor(self.nb_batches).random_()
mp_args = []
@@
-61,6
+62,7
@@
class VignetteSet:
self.data = []
for b in range(0, self.nb_batches):
self.data.append(generate_one_batch(mp_args[b]))
self.data = []
for b in range(0, self.nb_batches):
self.data.append(generate_one_batch(mp_args[b]))
+ if logger is not None: logger(self.nb_batches * self.batch_size, b * self.batch_size)
# Weird thing going on with the multi-processing, waiting for more info
# Weird thing going on with the multi-processing, waiting for more info
@@
-88,17
+90,19
@@
class VignetteSet:
######################################################################
class CompressedVignetteSet:
######################################################################
class CompressedVignetteSet:
- def __init__(self, problem_number, nb_samples, batch_size, cuda = False):
+ def __init__(self, problem_number, nb_samples, batch_size, cuda = False
, logger = None
):
if nb_samples%batch_size > 0:
print('nb_samples must be a multiple of batch_size')
raise
self.cuda = cuda
if nb_samples%batch_size > 0:
print('nb_samples must be a multiple of batch_size')
raise
self.cuda = cuda
- self.batch_size = batch_size
self.problem_number = problem_number
self.problem_number = problem_number
- self.nb_batches = nb_samples // batch_size
- self.nb_samples = self.nb_batches * self.batch_size
+
+ self.batch_size = batch_size
+ self.nb_samples = nb_samples
+ self.nb_batches = self.nb_samples // self.batch_size
+
self.targets = []
self.input_storages = []
self.targets = []
self.input_storages = []
@@
-111,6
+115,7
@@
class CompressedVignetteSet:
acc_sq += input.float().pow(2).sum() / input.numel()
self.targets.append(target)
self.input_storages.append(svrt.compress(input.storage()))
acc_sq += input.float().pow(2).sum() / input.numel()
self.targets.append(target)
self.input_storages.append(svrt.compress(input.storage()))
+ if logger is not None: logger(self.nb_batches * self.batch_size, b * self.batch_size)
self.mean = acc / self.nb_batches
self.std = sqrt(acc_sq / self.nb_batches - self.mean * self.mean)
self.mean = acc / self.nb_batches
self.std = sqrt(acc_sq / self.nb_batches - self.mean * self.mean)