Update.
[pysvrt.git] / cnn-svrt.py
index 142b81f..7fe2db2 100755 (executable)
@@ -25,18 +25,22 @@ import time
 import argparse
 import math
 import distutils.util
 import argparse
 import math
 import distutils.util
+import re
 
 from colorama import Fore, Back, Style
 
 # Pytorch
 
 import torch
 
 from colorama import Fore, Back, Style
 
 # Pytorch
 
 import torch
+import torchvision
 
 from torch import optim
 
 from torch import optim
+from torch import multiprocessing
 from torch import FloatTensor as Tensor
 from torch.autograd import Variable
 from torch import nn
 from torch.nn import functional as fn
 from torch import FloatTensor as Tensor
 from torch.autograd import Variable
 from torch import nn
 from torch.nn import functional as fn
+
 from torchvision import datasets, transforms, utils
 
 # SVRT
 from torchvision import datasets, transforms, utils
 
 # SVRT
@@ -56,6 +60,13 @@ parser.add_argument('--nb_train_samples',
 parser.add_argument('--nb_test_samples',
                     type = int, default = 10000)
 
 parser.add_argument('--nb_test_samples',
                     type = int, default = 10000)
 
+parser.add_argument('--nb_validation_samples',
+                    type = int, default = 10000)
+
+parser.add_argument('--validation_error_threshold',
+                    type = float, default = 0.0,
+                    help = 'Early training termination criterion')
+
 parser.add_argument('--nb_epochs',
                     type = int, default = 50)
 
 parser.add_argument('--nb_epochs',
                     type = int, default = 50)
 
@@ -65,6 +76,9 @@ parser.add_argument('--batch_size',
 parser.add_argument('--log_file',
                     type = str, default = 'default.log')
 
 parser.add_argument('--log_file',
                     type = str, default = 'default.log')
 
+parser.add_argument('--nb_exemplar_vignettes',
+                    type = int, default = -1)
+
 parser.add_argument('--compress_vignettes',
                     type = distutils.util.strtobool, default = 'True',
                     help = 'Use lossless compression to reduce the memory footprint')
 parser.add_argument('--compress_vignettes',
                     type = distutils.util.strtobool, default = 'True',
                     help = 'Use lossless compression to reduce the memory footprint')
@@ -77,19 +91,25 @@ parser.add_argument('--test_loaded_models',
                     type = distutils.util.strtobool, default = 'False',
                     help = 'Should we compute the test errors of loaded models')
 
                     type = distutils.util.strtobool, default = 'False',
                     help = 'Should we compute the test errors of loaded models')
 
+parser.add_argument('--problems',
+                    type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
+                    help = 'What problems to process')
+
 args = parser.parse_args()
 
 ######################################################################
 
 args = parser.parse_args()
 
 ######################################################################
 
-log_file = open(args.log_file, 'w')
+log_file = open(args.log_file, 'a')
 pred_log_t = None
 pred_log_t = None
+last_tag_t = time.time()
 
 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
 
 # Log and prints the string, with a time stamp. Does not log the
 # remark
 
 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
 
 # Log and prints the string, with a time stamp. Does not log the
 # remark
+
 def log_string(s, remark = ''):
 def log_string(s, remark = ''):
-    global pred_log_t
+    global pred_log_t, last_tag_t
 
     t = time.time()
 
 
     t = time.time()
 
@@ -100,10 +120,14 @@ def log_string(s, remark = ''):
 
     pred_log_t = t
 
 
     pred_log_t = t
 
-    log_file.write('[' + time.ctime() + '] ' + elapsed + ' ' + s + '\n')
+    if t > last_tag_t + 3600:
+        last_tag_t = t
+        print(Fore.RED + time.ctime() + Style.RESET_ALL)
+
+    log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n')
     log_file.flush()
 
     log_file.flush()
 
-    print(Fore.BLUE + '[' + time.ctime() + '] ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
+    print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
 
 ######################################################################
 
 
 ######################################################################
 
@@ -190,7 +214,22 @@ class AfrozeDeepNet(nn.Module):
 
 ######################################################################
 
 
 ######################################################################
 
-def train_model(model, train_set):
+def nb_errors(model, data_set):
+    ne = 0
+    for b in range(0, data_set.nb_batches):
+        input, target = data_set.get_batch(b)
+        output = model.forward(Variable(input))
+        wta_prediction = output.data.max(1)[1].view(-1)
+
+        for i in range(0, data_set.batch_size):
+            if wta_prediction[i] != target[i]:
+                ne = ne + 1
+
+    return ne
+
+######################################################################
+
+def train_model(model, train_set, validation_set):
     batch_size = args.batch_size
     criterion = nn.CrossEntropyLoss()
 
     batch_size = args.batch_size
     criterion = nn.CrossEntropyLoss()
 
@@ -212,25 +251,24 @@ def train_model(model, train_set):
             loss.backward()
             optimizer.step()
         dt = (time.time() - start_t) / (e + 1)
             loss.backward()
             optimizer.step()
         dt = (time.time() - start_t) / (e + 1)
+
         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                    ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
 
         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                    ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
 
-    return model
-
-######################################################################
+        if validation_set is not None:
+            nb_validation_errors = nb_errors(model, validation_set)
 
 
-def nb_errors(model, data_set):
-    ne = 0
-    for b in range(0, data_set.nb_batches):
-        input, target = data_set.get_batch(b)
-        output = model.forward(Variable(input))
-        wta_prediction = output.data.max(1)[1].view(-1)
+            log_string('validation_error {:.02f}% {:d} {:d}'.format(
+                100 * nb_validation_errors / validation_set.nb_samples,
+                nb_validation_errors,
+                validation_set.nb_samples)
+            )
 
 
-        for i in range(0, data_set.batch_size):
-            if wta_prediction[i] != target[i]:
-                ne = ne + 1
+            if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold:
+                log_string('below validation_error_threshold')
+                break
 
 
-    return ne
+    return model
 
 ######################################################################
 
 
 ######################################################################
 
@@ -250,16 +288,33 @@ def int_to_suffix(n):
 class vignette_logger():
     def __init__(self, delay_min = 60):
         self.start_t = time.time()
 class vignette_logger():
     def __init__(self, delay_min = 60):
         self.start_t = time.time()
+        self.last_t = self.start_t
         self.delay_min = delay_min
 
     def __call__(self, n, m):
         t = time.time()
         self.delay_min = delay_min
 
     def __call__(self, n, m):
         t = time.time()
-        if t > self.start_t + self.delay_min:
+        if t > self.last_t + self.delay_min:
             dt = (t - self.start_t) / m
             log_string('sample_generation {:d} / {:d}'.format(
                 m,
                 n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
             )
             dt = (t - self.start_t) / m
             log_string('sample_generation {:d} / {:d}'.format(
                 m,
                 n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
             )
+            self.last_t = t
+
+def save_examplar_vignettes(data_set, nb, name):
+    n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)
+
+    for k in range(0, nb):
+        b = n[k] // data_set.batch_size
+        m = n[k] % data_set.batch_size
+        i, t = data_set.get_batch(b)
+        i = i[m].float()
+        i.sub_(i.min())
+        i.div_(i.max())
+        if k == 0: patchwork = Tensor(nb, 1, i.size(1), i.size(2))
+        patchwork[k].copy_(i)
+
+    torchvision.utils.save_image(patchwork, name)
 
 ######################################################################
 
 
 ######################################################################
 
@@ -267,6 +322,8 @@ if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_
     print('The number of samples must be a multiple of the batch size.')
     raise
 
     print('The number of samples must be a multiple of the batch size.')
     raise
 
+log_string('############### start ###############')
+
 if args.compress_vignettes:
     log_string('using_compressed_vignettes')
     VignetteSet = svrtset.CompressedVignetteSet
 if args.compress_vignettes:
     log_string('using_compressed_vignettes')
     VignetteSet = svrtset.CompressedVignetteSet
@@ -274,7 +331,7 @@ else:
     log_string('using_uncompressed_vignettes')
     VignetteSet = svrtset.VignetteSet
 
     log_string('using_uncompressed_vignettes')
     VignetteSet = svrtset.VignetteSet
 
-for problem_number in range(1, 24):
+for problem_number in map(int, args.problems.split(',')):
 
     log_string('############### problem ' + str(problem_number) + ' ###############')
 
 
     log_string('############### problem ' + str(problem_number) + ' ###############')
 
@@ -321,7 +378,19 @@ for problem_number in range(1, 24):
             train_set.nb_samples / (time.time() - t))
         )
 
             train_set.nb_samples / (time.time() - t))
         )
 
-        train_model(model, train_set)
+        if args.nb_exemplar_vignettes > 0:
+            save_examplar_vignettes(train_set, args.nb_exemplar_vignettes,
+                                    'examplar_{:d}.png'.format(problem_number))
+
+        if args.validation_error_threshold > 0.0:
+            validation_set = VignetteSet(problem_number,
+                                         args.nb_validation_samples, args.batch_size,
+                                         cuda = torch.cuda.is_available(),
+                                         logger = vignette_logger())
+        else:
+            validation_set = None
+
+        train_model(model, train_set, validation_set)
         torch.save(model.state_dict(), model_filename)
         log_string('saved_model ' + model_filename)
 
         torch.save(model.state_dict(), model_filename)
         log_string('saved_model ' + model_filename)
 
@@ -345,10 +414,6 @@ for problem_number in range(1, 24):
                                args.nb_test_samples, args.batch_size,
                                cuda = torch.cuda.is_available())
 
                                args.nb_test_samples, args.batch_size,
                                cuda = torch.cuda.is_available())
 
-        log_string('data_generation {:0.2f} samples / s'.format(
-            test_set.nb_samples / (time.time() - t))
-        )
-
         nb_test_errors = nb_errors(model, test_set)
 
         log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
         nb_test_errors = nb_errors(model, test_set)
 
         log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(