projects
/
pysvrt.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
3feef90
)
Cleaning up.
author
Francois Fleuret
<francois@fleuret.org>
Sat, 17 Jun 2017 18:55:53 +0000
(20:55 +0200)
committer
Francois Fleuret
<francois@fleuret.org>
Sat, 17 Jun 2017 18:55:53 +0000
(20:55 +0200)
cnn-svrt.py
patch
|
blob
|
history
diff --git
a/cnn-svrt.py
b/cnn-svrt.py
index
5dc91c8
..
153bdc9
100755
(executable)
--- a/
cnn-svrt.py
+++ b/
cnn-svrt.py
@@
-40,7
+40,7
@@
from torchvision import datasets, transforms, utils
# SVRT
# SVRT
-
from vignette_set import VignetteSet, CompressedVignetteS
et
+
import vignette_s
et
######################################################################
######################################################################
@@
-268,17
+268,21
@@
if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_
print('The number of samples must be a multiple of the batch size.')
raise
print('The number of samples must be a multiple of the batch size.')
raise
+if args.compress_vignettes:
+ VignetteSet = vignette_set.CompressedVignetteSet
+else:
+ VignetteSet = vignette_set.VignetteSet
+
for problem_number in range(1, 24):
for problem_number in range(1, 24):
- log_string('
**** problem ' + str(problem_number) + ' ****
')
+ log_string('
############### problem ' + str(problem_number) + ' ###############
')
if args.deep_model:
model = AfrozeDeepNet()
else:
model = AfrozeShallowNet()
if args.deep_model:
model = AfrozeDeepNet()
else:
model = AfrozeShallowNet()
- if torch.cuda.is_available():
- model.cuda()
+ if torch.cuda.is_available(): model.cuda()
model_filename = model.name + '_' + \
str(problem_number) + '_' + \
model_filename = model.name + '_' + \
str(problem_number) + '_' + \
@@
-288,6
+292,9
@@
for problem_number in range(1, 24):
for p in model.parameters(): nb_parameters += p.numel()
log_string('nb_parameters {:d}'.format(nb_parameters))
for p in model.parameters(): nb_parameters += p.numel()
log_string('nb_parameters {:d}'.format(nb_parameters))
+ ##################################################
+ # Tries to load the model
+
need_to_train = False
try:
model.load_state_dict(torch.load(model_filename))
need_to_train = False
try:
model.load_state_dict(torch.load(model_filename))
@@
-295,20
+302,18
@@
for problem_number in range(1, 24):
except:
need_to_train = True
except:
need_to_train = True
+ ##################################################
+ # Train if necessary
+
if need_to_train:
log_string('training_model ' + model_filename)
t = time.time()
if need_to_train:
log_string('training_model ' + model_filename)
t = time.time()
- if args.compress_vignettes:
- train_set = CompressedVignetteSet(problem_number,
- args.nb_train_samples, args.batch_size,
- cuda = torch.cuda.is_available())
- else:
- train_set = VignetteSet(problem_number,
- args.nb_train_samples, args.batch_size,
- cuda = torch.cuda.is_available())
+ train_set = VignetteSet(problem_number,
+ args.nb_train_samples, args.batch_size,
+ cuda = torch.cuda.is_available())
log_string('data_generation {:0.2f} samples / s'.format(
train_set.nb_samples / (time.time() - t))
log_string('data_generation {:0.2f} samples / s'.format(
train_set.nb_samples / (time.time() - t))
@@
-327,18
+332,16
@@
for problem_number in range(1, 24):
train_set.nb_samples)
)
train_set.nb_samples)
)
+ ##################################################
+ # Test if necessary
+
if need_to_train or args.test_loaded_models:
t = time.time()
if need_to_train or args.test_loaded_models:
t = time.time()
- if args.compress_vignettes:
- test_set = CompressedVignetteSet(problem_number,
- args.nb_test_samples, args.batch_size,
- cuda = torch.cuda.is_available())
- else:
- test_set = VignetteSet(problem_number,
- args.nb_test_samples, args.batch_size,
- cuda = torch.cuda.is_available())
+ test_set = VignetteSet(problem_number,
+ args.nb_test_samples, args.batch_size,
+ cuda = torch.cuda.is_available())
log_string('data_generation {:0.2f} samples / s'.format(
test_set.nb_samples / (time.time() - t))
log_string('data_generation {:0.2f} samples / s'.format(
test_set.nb_samples / (time.time() - t))