projects
/
pysvrt.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Cosmetics.
[pysvrt.git]
/
vignette_set.py
diff --git
a/vignette_set.py
b/vignette_set.py
index
72880ba
..
5062f3e
100755
(executable)
--- a/
vignette_set.py
+++ b/
vignette_set.py
@@
-22,7
+22,7
@@
import torch
from math import sqrt
import torch
from math import sqrt
-from
multiprocessing import Pool, cpu_count
+from
torch import multiprocessing
from torch import Tensor
from torch.autograd import Variable
from torch import Tensor
from torch.autograd import Variable
@@
-32,31
+32,40
@@
import svrt
######################################################################
def generate_one_batch(s):
######################################################################
def generate_one_batch(s):
- svrt.seed(s)
- target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
+ problem_number, batch_size, random_seed = s
+ svrt.seed(random_seed)
+ target = torch.LongTensor(batch_size).bernoulli_(0.5)
input = svrt.generate_vignettes(problem_number, target)
input = input.float().view(input.size(0), 1, input.size(1), input.size(2))
input = svrt.generate_vignettes(problem_number, target)
input = input.float().view(input.size(0), 1, input.size(1), input.size(2))
- if self.cuda:
- input = input.cuda()
- target = target.cuda()
return [ input, target ]
class VignetteSet:
return [ input, target ]
class VignetteSet:
- def __init__(self, problem_number, nb_batches, batch_size, cuda = False):
+ def __init__(self, problem_number, nb_samples, batch_size, cuda = False):
+
+ if nb_samples%batch_size > 0:
+ print('nb_samples must be a mutiple of batch_size')
+ raise
+
self.cuda = cuda
self.batch_size = batch_size
self.problem_number = problem_number
self.cuda = cuda
self.batch_size = batch_size
self.problem_number = problem_number
- self.nb_batches = nb_
batches
+ self.nb_batches = nb_
samples // batch_size
self.nb_samples = self.nb_batches * self.batch_size
self.nb_samples = self.nb_batches * self.batch_size
- seed_list = torch.LongTensor(self.nb_batches).random_().tolist()
+ seeds = torch.LongTensor(self.nb_batches).random_()
+ mp_args = []
+ for b in range(0, self.nb_batches):
+ mp_args.append( [ problem_number, batch_size, seeds[b] ])
-
#
self.data = []
-
#
for b in range(0, self.nb_batches):
-
# self.data.append(generate_one_batch(seed_list
[b]))
+ self.data = []
+ for b in range(0, self.nb_batches):
+
self.data.append(generate_one_batch(mp_args
[b]))
- self.data = Pool(cpu_count()).map(generate_one_batch, seed_list)
+ # Weird thing going on with the multi-processing, waiting for more info
+
+ # pool = multiprocessing.Pool(multiprocessing.cpu_count())
+ # self.data = pool.map(generate_one_batch, mp_args)
acc = 0.0
acc_sq = 0.0
acc = 0.0
acc_sq = 0.0
@@
-69,6
+78,9
@@
class VignetteSet:
std = sqrt(acc_sq / self.nb_batches - mean * mean)
for b in range(0, self.nb_batches):
self.data[b][0].sub_(mean).div_(std)
std = sqrt(acc_sq / self.nb_batches - mean * mean)
for b in range(0, self.nb_batches):
self.data[b][0].sub_(mean).div_(std)
+ if cuda:
+ self.data[b][0] = self.data[b][0].cuda()
+ self.data[b][1] = self.data[b][1].cuda()
def get_batch(self, b):
return self.data[b]
def get_batch(self, b):
return self.data[b]
@@
-76,11
+88,16
@@
class VignetteSet:
######################################################################
class CompressedVignetteSet:
######################################################################
class CompressedVignetteSet:
- def __init__(self, problem_number, nb_batches, batch_size, cuda = False):
+ def __init__(self, problem_number, nb_samples, batch_size, cuda = False):
+
+ if nb_samples%batch_size > 0:
+ print('nb_samples must be a mutiple of batch_size')
+ raise
+
self.cuda = cuda
self.batch_size = batch_size
self.problem_number = problem_number
self.cuda = cuda
self.batch_size = batch_size
self.problem_number = problem_number
- self.nb_batches = nb_
batches
+ self.nb_batches = nb_
samples // batch_size
self.nb_samples = self.nb_batches * self.batch_size
self.targets = []
self.input_storages = []
self.nb_samples = self.nb_batches * self.batch_size
self.targets = []
self.input_storages = []