projects
/
pysvrt.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Learning continues from where it stopped if a saved state is found.
[pysvrt.git]
/
cnn-svrt.py
diff --git a/cnn-svrt.py b/cnn-svrt.py
index 1511e82..fb1cad9 100755 (executable)
--- a/cnn-svrt.py
+++ b/cnn-svrt.py
@@ -279,7 +279,7 @@ def nb_errors(model, data_set):
 ######################################################################
 ######################################################################
-def train_model(model, train_set, validation_set):
+def train_model(model, model_filename, train_set, validation_set, nb_epochs_done = 0):
     batch_size = args.batch_size
     criterion = nn.CrossEntropyLoss()
@@ -290,7 +290,7 @@ def train_model(model, train_set, validation_set):
     start_t = time.time()
-    for e in range(0, args.nb_epochs):
+    for e in range(nb_epochs_done, args.nb_epochs):
         acc_loss = 0.0
         for b in range(0, train_set.nb_batches):
             input, target = train_set.get_batch(b)
@@ -305,6 +305,8 @@ def train_model(model, train_set, validation_set):
         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                    ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
+        torch.save([ model.state_dict(), e + 1 ], model_filename)
+
     if validation_set is not None:
         nb_validation_errors = nb_errors(model, validation_set)
@@ -404,7 +406,7 @@ for problem_number in map(int, args.problems.split(',')):
     model_filename = model.name + '_pb:' + \
                      str(problem_number) + '_ns:' + \
-                     int_to_suffix(args.nb_train_samples) + '.param'
+                     int_to_suffix(args.nb_train_samples) + '.state'
     nb_parameters = 0
     for p in model.parameters(): nb_parameters += p.numel()
@@ -415,15 +417,17 @@ for problem_number in map(int, args.problems.split(',')):
     need_to_train = False
     try:
-        model.load_state_dict(torch.load(model_filename))
+        model_state_dict, nb_epochs_done = torch.load(model_filename)
+        model.load_state_dict(model_state_dict)
         log_string('loaded_model ' + model_filename)
     except:
-        need_to_train = True
+        nb_epochs_done = 0
+
     ##################################################
     # Train if necessary
     ##################################################
-    if need_to_train:
+    if nb_epochs_done < args.nb_epochs:
         log_string('training_model ' + model_filename)
@@ -450,8 +454,7 @@ for problem_number in map(int, args.problems.split(',')):
         else:
             validation_set = None
-        train_model(model, train_set, validation_set)
-        torch.save(model.state_dict(), model_filename)
+        train_model(model, model_filename, train_set, validation_set, nb_epochs_done = nb_epochs_done)
         log_string('saved_model ' + model_filename)
     nb_train_errors = nb_errors(model, train_set)