Update.
[pysvrt.git] / cnn-svrt.py
index e5ecf76..7fe2db2 100755 (executable)
@@ -25,18 +25,22 @@ import time
 import argparse
 import math
 import distutils.util
 import argparse
 import math
 import distutils.util
+import re
 
 from colorama import Fore, Back, Style
 
 # Pytorch
 
 import torch
 
 from colorama import Fore, Back, Style
 
 # Pytorch
 
 import torch
+import torchvision
 
 from torch import optim
 
 from torch import optim
+from torch import multiprocessing
 from torch import FloatTensor as Tensor
 from torch.autograd import Variable
 from torch import nn
 from torch.nn import functional as fn
 from torch import FloatTensor as Tensor
 from torch.autograd import Variable
 from torch import nn
 from torch.nn import functional as fn
+
 from torchvision import datasets, transforms, utils
 
 # SVRT
 from torchvision import datasets, transforms, utils
 
 # SVRT
@@ -56,6 +60,13 @@ parser.add_argument('--nb_train_samples',
 parser.add_argument('--nb_test_samples',
                     type = int, default = 10000)
 
 parser.add_argument('--nb_test_samples',
                     type = int, default = 10000)
 
+parser.add_argument('--nb_validation_samples',
+                    type = int, default = 10000)
+
+parser.add_argument('--validation_error_threshold',
+                    type = float, default = 0.0,
+                    help = 'Early training termination criterion')
+
 parser.add_argument('--nb_epochs',
                     type = int, default = 50)
 
 parser.add_argument('--nb_epochs',
                     type = int, default = 50)
 
@@ -65,6 +76,9 @@ parser.add_argument('--batch_size',
 parser.add_argument('--log_file',
                     type = str, default = 'default.log')
 
 parser.add_argument('--log_file',
                     type = str, default = 'default.log')
 
+parser.add_argument('--nb_exemplar_vignettes',
+                    type = int, default = -1)
+
 parser.add_argument('--compress_vignettes',
                     type = distutils.util.strtobool, default = 'True',
                     help = 'Use lossless compression to reduce the memory footprint')
 parser.add_argument('--compress_vignettes',
                     type = distutils.util.strtobool, default = 'True',
                     help = 'Use lossless compression to reduce the memory footprint')
@@ -87,13 +101,15 @@ args = parser.parse_args()
 
 log_file = open(args.log_file, 'a')
 pred_log_t = None
 
 log_file = open(args.log_file, 'a')
 pred_log_t = None
+last_tag_t = time.time()
 
 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
 
 # Log and prints the string, with a time stamp. Does not log the
 # remark
 
 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
 
 # Log and prints the string, with a time stamp. Does not log the
 # remark
+
 def log_string(s, remark = ''):
 def log_string(s, remark = ''):
-    global pred_log_t
+    global pred_log_t, last_tag_t
 
     t = time.time()
 
 
     t = time.time()
 
@@ -104,10 +120,14 @@ def log_string(s, remark = ''):
 
     pred_log_t = t
 
 
     pred_log_t = t
 
-    log_file.write('[' + time.ctime() + '] ' + elapsed + ' ' + s + '\n')
+    if t > last_tag_t + 3600:
+        last_tag_t = t
+        print(Fore.RED + time.ctime() + Style.RESET_ALL)
+
+    log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n')
     log_file.flush()
 
     log_file.flush()
 
-    print(Fore.BLUE + '[' + time.ctime() + '] ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
+    print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
 
 ######################################################################
 
 
 ######################################################################
 
@@ -194,7 +214,22 @@ class AfrozeDeepNet(nn.Module):
 
 ######################################################################
 
 
 ######################################################################
 
-def train_model(model, train_set):
+def nb_errors(model, data_set):
+    ne = 0
+    for b in range(0, data_set.nb_batches):
+        input, target = data_set.get_batch(b)
+        output = model.forward(Variable(input))
+        wta_prediction = output.data.max(1)[1].view(-1)
+
+        for i in range(0, data_set.batch_size):
+            if wta_prediction[i] != target[i]:
+                ne = ne + 1
+
+    return ne
+
+######################################################################
+
+def train_model(model, train_set, validation_set):
     batch_size = args.batch_size
     criterion = nn.CrossEntropyLoss()
 
     batch_size = args.batch_size
     criterion = nn.CrossEntropyLoss()
 
@@ -216,25 +251,24 @@ def train_model(model, train_set):
             loss.backward()
             optimizer.step()
         dt = (time.time() - start_t) / (e + 1)
             loss.backward()
             optimizer.step()
         dt = (time.time() - start_t) / (e + 1)
+
         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                    ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
 
         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                    ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
 
-    return model
+        if validation_set is not None:
+            nb_validation_errors = nb_errors(model, validation_set)
 
 
-######################################################################
+            log_string('validation_error {:.02f}% {:d} {:d}'.format(
+                100 * nb_validation_errors / validation_set.nb_samples,
+                nb_validation_errors,
+                validation_set.nb_samples)
+            )
 
 
-def nb_errors(model, data_set):
-    ne = 0
-    for b in range(0, data_set.nb_batches):
-        input, target = data_set.get_batch(b)
-        output = model.forward(Variable(input))
-        wta_prediction = output.data.max(1)[1].view(-1)
+            if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold:
+                log_string('below validation_error_threshold')
+                break
 
 
-        for i in range(0, data_set.batch_size):
-            if wta_prediction[i] != target[i]:
-                ne = ne + 1
-
-    return ne
+    return model
 
 ######################################################################
 
 
 ######################################################################
 
@@ -267,6 +301,21 @@ class vignette_logger():
             )
             self.last_t = t
 
             )
             self.last_t = t
 
+def save_examplar_vignettes(data_set, nb, name):
+    n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)
+
+    for k in range(0, nb):
+        b = n[k] // data_set.batch_size
+        m = n[k] % data_set.batch_size
+        i, t = data_set.get_batch(b)
+        i = i[m].float()
+        i.sub_(i.min())
+        i.div_(i.max())
+        if k == 0: patchwork = Tensor(nb, 1, i.size(1), i.size(2))
+        patchwork[k].copy_(i)
+
+    torchvision.utils.save_image(patchwork, name)
+
 ######################################################################
 
 if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
 ######################################################################
 
 if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
@@ -329,7 +378,19 @@ for problem_number in map(int, args.problems.split(',')):
             train_set.nb_samples / (time.time() - t))
         )
 
             train_set.nb_samples / (time.time() - t))
         )
 
-        train_model(model, train_set)
+        if args.nb_exemplar_vignettes > 0:
+            save_examplar_vignettes(train_set, args.nb_exemplar_vignettes,
+                                    'examplar_{:d}.png'.format(problem_number))
+
+        if args.validation_error_threshold > 0.0:
+            validation_set = VignetteSet(problem_number,
+                                         args.nb_validation_samples, args.batch_size,
+                                         cuda = torch.cuda.is_available(),
+                                         logger = vignette_logger())
+        else:
+            validation_set = None
+
+        train_model(model, train_set, validation_set)
         torch.save(model.state_dict(), model_filename)
         log_string('saved_model ' + model_filename)
 
         torch.save(model.state_dict(), model_filename)
         log_string('saved_model ' + model_filename)
 
@@ -353,10 +414,6 @@ for problem_number in map(int, args.problems.split(',')):
                                args.nb_test_samples, args.batch_size,
                                cuda = torch.cuda.is_available())
 
                                args.nb_test_samples, args.batch_size,
                                cuda = torch.cuda.is_available())
 
-        log_string('data_generation {:0.2f} samples / s'.format(
-            test_set.nb_samples / (time.time() - t))
-        )
-
         nb_test_errors = nb_errors(model, test_set)
 
         log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
         nb_test_errors = nb_errors(model, test_set)
 
         log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(