projects
/
pysvrt.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (from parent 1:
9f566b0
)
Update.
author
Francois Fleuret
<francois@fleuret.org>
Thu, 22 Jun 2017 06:05:25 +0000
(08:05 +0200)
committer
Francois Fleuret
<francois@fleuret.org>
Thu, 22 Jun 2017 06:05:25 +0000
(08:05 +0200)
cnn-svrt.py
patch
|
blob
|
history
diff --git
a/cnn-svrt.py
b/cnn-svrt.py
index
a41d42c
..
d6c7169
100755
(executable)
--- a/
cnn-svrt.py
+++ b/
cnn-svrt.py
@@
-32,6
+32,7
@@
from colorama import Fore, Back, Style
# Pytorch
import torch
# Pytorch
import torch
+import torchvision
from torch import optim
from torch import FloatTensor as Tensor
from torch import optim
from torch import FloatTensor as Tensor
@@
-73,6
+74,9
@@
parser.add_argument('--batch_size',
parser.add_argument('--log_file',
type = str, default = 'default.log')
parser.add_argument('--log_file',
type = str, default = 'default.log')
+parser.add_argument('--nb_exemplar_vignettes',
+ type = int, default = -1)
+
parser.add_argument('--compress_vignettes',
type = distutils.util.strtobool, default = 'True',
help = 'Use lossless compression to reduce the memory footprint')
parser.add_argument('--compress_vignettes',
type = distutils.util.strtobool, default = 'True',
help = 'Use lossless compression to reduce the memory footprint')
@@
-295,6
+299,21
@@
class vignette_logger():
)
self.last_t = t
)
self.last_t = t
+def save_examplar_vignettes(data_set, nb, name):
+ n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)
+
+ for k in range(0, nb):
+ b = n[k] // data_set.batch_size
+ m = n[k] % data_set.batch_size
+ i, t = data_set.get_batch(b)
+ i = i[m].float()
+ i.sub_(i.min())
+ i.div_(i.max())
+ if k == 0: patchwork = Tensor(nb, 1, i.size(1), i.size(2))
+ patchwork[k].copy_(i)
+
+ torchvision.utils.save_image(patchwork, name)
+
######################################################################
if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
######################################################################
if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
@@
-357,6
+376,10
@@
for problem_number in map(int, args.problems.split(',')):
train_set.nb_samples / (time.time() - t))
)
train_set.nb_samples / (time.time() - t))
)
+ if args.nb_exemplar_vignettes > 0:
+ save_examplar_vignettes(train_set, args.nb_exemplar_vignettes,
+ 'examplar_{:d}.png'.format(problem_number))
+
if args.validation_error_threshold > 0.0:
validation_set = VignetteSet(problem_number,
args.nb_validation_samples, args.batch_size,
if args.validation_error_threshold > 0.0:
validation_set = VignetteSet(problem_number,
args.nb_validation_samples, args.batch_size,