projects
/
pysvrt.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
abbbb61
)
Trying to make multiprocessing and cuda friends with each other.
author
Francois Fleuret
<francois@fleuret.org>
Fri, 16 Jun 2017 08:39:11 +0000
(10:39 +0200)
committer
Francois Fleuret
<francois@fleuret.org>
Fri, 16 Jun 2017 08:39:11 +0000
(10:39 +0200)
vignette_set.py
patch
|
blob
|
history
diff --git
a/vignette_set.py
b/vignette_set.py
index
c46beea
..
19a6f33
100755
(executable)
--- a/
vignette_set.py
+++ b/
vignette_set.py
@@
-22,7
+22,7
@@
import torch
from math import sqrt
import torch
from math import sqrt
-from multiprocessing import Pool, cpu_count
+from
torch.
multiprocessing import Pool, cpu_count
from torch import Tensor
from torch.autograd import Variable
from torch import Tensor
from torch.autograd import Variable
@@
-32,14
+32,11
@@
import svrt
######################################################################
def generate_one_batch(s):
######################################################################
def generate_one_batch(s):
- problem_number, batch_size,
cuda,
random_seed = s
+ problem_number, batch_size, random_seed = s
svrt.seed(random_seed)
target = torch.LongTensor(batch_size).bernoulli_(0.5)
input = svrt.generate_vignettes(problem_number, target)
input = input.float().view(input.size(0), 1, input.size(1), input.size(2))
svrt.seed(random_seed)
target = torch.LongTensor(batch_size).bernoulli_(0.5)
input = svrt.generate_vignettes(problem_number, target)
input = input.float().view(input.size(0), 1, input.size(1), input.size(2))
- if cuda:
- input = input.cuda()
- target = target.cuda()
return [ input, target ]
class VignetteSet:
return [ input, target ]
class VignetteSet:
@@
-54,7
+51,7
@@
class VignetteSet:
seeds = torch.LongTensor(self.nb_batches).random_()
mp_args = []
for b in range(0, self.nb_batches):
seeds = torch.LongTensor(self.nb_batches).random_()
mp_args = []
for b in range(0, self.nb_batches):
- mp_args.append( [ problem_number, batch_size,
cuda,
seeds[b] ])
+ mp_args.append( [ problem_number, batch_size, seeds[b] ])
# self.data = []
# for b in range(0, self.nb_batches):
# self.data = []
# for b in range(0, self.nb_batches):
@@
-73,6
+70,9
@@
class VignetteSet:
std = sqrt(acc_sq / self.nb_batches - mean * mean)
for b in range(0, self.nb_batches):
self.data[b][0].sub_(mean).div_(std)
std = sqrt(acc_sq / self.nb_batches - mean * mean)
for b in range(0, self.nb_batches):
self.data[b][0].sub_(mean).div_(std)
+ if cuda:
+ self.data[b][0] = self.data[b][0].cuda()
+ self.data[b][1] = self.data[b][1].cuda()
def get_batch(self, b):
return self.data[b]
def get_batch(self, b):
return self.data[b]