Update.
diff --git a/cnn-svrt.py b/cnn-svrt.py
index 142b81f..7fe2db2 100755
--- a/cnn-svrt.py
+++ b/cnn-svrt.py
@@ -25,18 +25,22 @@
 import time
 import argparse
 import math
 import distutils.util
+import re

 from colorama import Fore, Back, Style

 # Pytorch

 import torch
+import torchvision

 from torch import optim
+from torch import multiprocessing
 from torch import FloatTensor as Tensor
 from torch.autograd import Variable
 from torch import nn
 from torch.nn import functional as fn
+
 from torchvision import datasets, transforms, utils
 # SVRT
@@ -56,6 +60,13 @@ parser.add_argument('--nb_train_samples',
 parser.add_argument('--nb_test_samples',
                     type = int, default = 10000)
+parser.add_argument('--nb_validation_samples',
+                    type = int, default = 10000)
+
+parser.add_argument('--validation_error_threshold',
+                    type = float, default = 0.0,
+                    help = 'Early training termination criterion')
+
 parser.add_argument('--nb_epochs',
                     type = int, default = 50)
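A note on the two flags added above, which work together (the checks they feed are added to train_model() further down in this diff): the default threshold of 0.0 keeps the old behavior, while a positive value builds a validation set and terminates training early once the validation error rate reaches the threshold. A minimal sketch of that stopping rule, with hypothetical numbers:

    # Early-stopping rule configured by the new flags (numbers are made up).
    def should_stop(nb_validation_errors, nb_samples, threshold):
        # stop once the validation error rate is at or below the threshold
        return nb_validation_errors / nb_samples <= threshold

    print(should_stop(12, 10000, 0.01))   # True: 0.12% error <= 1% threshold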
@@ -65,6 +76,9 @@ parser.add_argument('--batch_size',
 parser.add_argument('--log_file',
                     type = str, default = 'default.log')
+parser.add_argument('--nb_exemplar_vignettes',
+                    type = int, default = -1)
+
 parser.add_argument('--compress_vignettes',
                     type = distutils.util.strtobool, default = 'True',
                     help = 'Use lossless compression to reduce the memory footprint')
@@ -77,19 +91,25 @@ parser.add_argument('--test_loaded_models',
                     type = distutils.util.strtobool, default = 'False',
                     help = 'Should we compute the test errors of loaded models')
+parser.add_argument('--problems',
+                    type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
+                    help = 'What problems to process')
+
 args = parser.parse_args()

 ######################################################################

-log_file = open(args.log_file, 'w')
+log_file = open(args.log_file, 'a')

 pred_log_t = None
+last_tag_t = time.time()

 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)

 # Log and prints the string, with a time stamp. Does not log the
 # remark
+
 def log_string(s, remark = ''):
-    global pred_log_t
+    global pred_log_t, last_tag_t

     t = time.time()
@@ -100,10 +120,14 @@ def log_string(s, remark = ''):
     pred_log_t = t

-    log_file.write('[' + time.ctime() + '] ' + elapsed + ' ' + s + '\n')
+    if t > last_tag_t + 3600:
+        last_tag_t = t
+        print(Fore.RED + time.ctime() + Style.RESET_ALL)
+
+    log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n')
     log_file.flush()

-    print(Fore.BLUE + '[' + time.ctime() + '] ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
+    print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)

 ######################################################################
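Two behavioral notes on the log_string() changes above: the bracketed '[...]' timestamp becomes a single underscore-joined token (this is what the new import re in the first hunk is for), so every field of a log line stays whitespace-delimited, and a bare wall-clock line is echoed in red at most once an hour. The token transformation in isolation:

    import re
    import time

    # time.ctime() contains spaces; collapsing them to underscores turns the
    # timestamp into one field, e.g. Sun_Jun_25_10:21:52_2017 (example value).
    print(re.sub(' ', '_', time.ctime()))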
@@ -190,7 +214,22 @@ class AfrozeDeepNet(nn.Module):
 ######################################################################

-def train_model(model, train_set):
+def nb_errors(model, data_set):
+    ne = 0
+    for b in range(0, data_set.nb_batches):
+        input, target = data_set.get_batch(b)
+        output = model.forward(Variable(input))
+        wta_prediction = output.data.max(1)[1].view(-1)
+
+        for i in range(0, data_set.batch_size):
+            if wta_prediction[i] != target[i]:
+                ne = ne + 1
+
+    return ne
+
+######################################################################
+
+def train_model(model, train_set, validation_set):
     batch_size = args.batch_size
     criterion = nn.CrossEntropyLoss()
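The nb_errors() helper above (moved here from below train_model(), unchanged) scores with a winner-take-all rule. If the output.data.max(1)[1].view(-1) idiom is unfamiliar, here it is on a standalone tensor rather than real model output:

    import torch

    # max(1) reduces over dimension 1 (the classes) and returns
    # (values, indices); [1] keeps the argmax, i.e. the predicted class.
    output = torch.Tensor([[0.2, 0.8],
                           [0.9, 0.1]])
    print(output.max(1)[1].view(-1))   # predicted classes: 1, 0

Note that the loop visits full batches only, so it relies on nb_samples being a multiple of the batch size, which the script enforces at startup.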
@@ -212,25 +251,24 @@ def train_model(model, train_set):
         loss.backward()
         optimizer.step()

         dt = (time.time() - start_t) / (e + 1)
+
         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                    ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')

-    return model
-
-######################################################################
+        if validation_set is not None:
+            nb_validation_errors = nb_errors(model, validation_set)

-def nb_errors(model, data_set):
-    ne = 0
-    for b in range(0, data_set.nb_batches):
-        input, target = data_set.get_batch(b)
-        output = model.forward(Variable(input))
-        wta_prediction = output.data.max(1)[1].view(-1)
+            log_string('validation_error {:.02f}% {:d} {:d}'.format(
+                100 * nb_validation_errors / validation_set.nb_samples,
+                nb_validation_errors,
+                validation_set.nb_samples)
+            )
-
-        for i in range(0, data_set.batch_size):
-            if wta_prediction[i] != target[i]:
-                ne = ne + 1
+
+            if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold:
+                log_string('below validation_error_threshold')
+                break

-    return ne
+    return model

 ######################################################################
@@ -250,16 +288,33 @@ def int_to_suffix(n):
 class vignette_logger():
     def __init__(self, delay_min = 60):
         self.start_t = time.time()
+        self.last_t = self.start_t
         self.delay_min = delay_min

     def __call__(self, n, m):
         t = time.time()
-        if t > self.start_t + self.delay_min:
+        if t > self.last_t + self.delay_min:
             dt = (t - self.start_t) / m
             log_string('sample_generation {:d} / {:d}'.format(
                 m,
                 n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
             )
+            self.last_t = t
+
+def save_examplar_vignettes(data_set, nb, name):
+    n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)
+
+    for k in range(0, nb):
+        b = n[k] // data_set.batch_size
+        m = n[k] % data_set.batch_size
+        i, t = data_set.get_batch(b)
+        i = i[m].float()
+        i.sub_(i.min())
+        i.div_(i.max())
+        if k == 0: patchwork = Tensor(nb, 1, i.size(1), i.size(2))
+        patchwork[k].copy_(i)
+
+    torchvision.utils.save_image(patchwork, name)

 ######################################################################
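Two separate changes above: vignette_logger now rate-limits against the time of the last message (last_t) instead of the start time, so generation progress is reported every delay_min rather than only once; and the new save_examplar_vignettes() decodes flat random sample indices into (batch, offset) pairs before normalizing each vignette to [0, 1] and packing the patchwork image that torchvision.utils.save_image writes out. The index decoding in isolation:

    # A flat sample index s in [0, nb_samples) maps to batch s // batch_size
    # and offset s % batch_size within that batch (hypothetical values).
    batch_size = 100
    s = 237
    print(s // batch_size, s % batch_size)   # 2 37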
@@ -267,6 +322,8 @@ if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_
     print('The number of samples must be a multiple of the batch size.')
     raise

+log_string('############### start ###############')
+
 if args.compress_vignettes:
     log_string('using_compressed_vignettes')
     VignetteSet = svrtset.CompressedVignetteSet
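Since the log file is now opened for appending (the 'w' to 'a' change earlier in this diff), successive runs accumulate in one file, and this start tag is what makes their boundaries recoverable. A sketch of counting runs in the default log file (file name taken from the flag's default):

    # Each run writes one 'start' marker line; count them.
    nb_runs = 0
    with open('default.log') as f:
        for line in f:
            if line.rstrip().endswith('############### start ###############'):
                nb_runs += 1
    print(nb_runs, 'run(s) logged')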
@@ -274,7 +331,7 @@ else:
     log_string('using_uncompressed_vignettes')
     VignetteSet = svrtset.VignetteSet

-for problem_number in range(1, 24):
+for problem_number in map(int, args.problems.split(',')):

     log_string('############### problem ' + str(problem_number) + ' ###############')
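The loop now iterates over --problems rather than the hard-wired range(1, 24); the flag's default string lists all 23 SVRT problems, so the default behavior is unchanged. The parsing amounts to:

    # Hypothetical flag value selecting three problems.
    problems = '1,5,21'
    print(list(map(int, problems.split(','))))   # [1, 5, 21]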
@@ -321,7 +378,19 @@ for problem_number in range(1, 24):
             train_set.nb_samples / (time.time() - t))
         )

-        train_model(model, train_set)
+        if args.nb_exemplar_vignettes > 0:
+            save_examplar_vignettes(train_set, args.nb_exemplar_vignettes,
+                                    'examplar_{:d}.png'.format(problem_number))
+
+        if args.validation_error_threshold > 0.0:
+            validation_set = VignetteSet(problem_number,
+                                         args.nb_validation_samples, args.batch_size,
+                                         cuda = torch.cuda.is_available(),
+                                         logger = vignette_logger())
+        else:
+            validation_set = None
+
+        train_model(model, train_set, validation_set)
         torch.save(model.state_dict(), model_filename)

         log_string('saved_model ' + model_filename)
@@ -345,10 +414,6 @@ for problem_number in range(1, 24):
                                args.nb_test_samples, args.batch_size,
                                cuda = torch.cuda.is_available())

-        log_string('data_generation {:0.2f} samples / s'.format(
-            test_set.nb_samples / (time.time() - t))
-        )
-
         nb_test_errors = nb_errors(model, test_set)

         log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(