Update.
[pysvrt.git] / cnn-svrt.py
index f3d350e..7fe2db2 100755 (executable)
@@ -25,18 +25,22 @@ import time
 import argparse
 import math
 import distutils.util
 import argparse
 import math
 import distutils.util
+import re
 
 from colorama import Fore, Back, Style
 
 # Pytorch
 
 import torch
 
 from colorama import Fore, Back, Style
 
 # Pytorch
 
 import torch
+import torchvision
 
 from torch import optim
 
 from torch import optim
+from torch import multiprocessing
 from torch import FloatTensor as Tensor
 from torch.autograd import Variable
 from torch import nn
 from torch.nn import functional as fn
 from torch import FloatTensor as Tensor
 from torch.autograd import Variable
 from torch import nn
 from torch.nn import functional as fn
+
 from torchvision import datasets, transforms, utils
 
 # SVRT
 from torchvision import datasets, transforms, utils
 
 # SVRT
@@ -72,6 +76,9 @@ parser.add_argument('--batch_size',
 parser.add_argument('--log_file',
                     type = str, default = 'default.log')
 
 parser.add_argument('--log_file',
                     type = str, default = 'default.log')
 
+parser.add_argument('--nb_exemplar_vignettes',
+                    type = int, default = -1)
+
 parser.add_argument('--compress_vignettes',
                     type = distutils.util.strtobool, default = 'True',
                     help = 'Use lossless compression to reduce the memory footprint')
 parser.add_argument('--compress_vignettes',
                     type = distutils.util.strtobool, default = 'True',
                     help = 'Use lossless compression to reduce the memory footprint')
@@ -94,13 +101,15 @@ args = parser.parse_args()
 
 log_file = open(args.log_file, 'a')
 pred_log_t = None
 
 log_file = open(args.log_file, 'a')
 pred_log_t = None
+last_tag_t = time.time()
 
 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
 
 # Log and prints the string, with a time stamp. Does not log the
 # remark
 
 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
 
 # Log and prints the string, with a time stamp. Does not log the
 # remark
+
 def log_string(s, remark = ''):
 def log_string(s, remark = ''):
-    global pred_log_t
+    global pred_log_t, last_tag_t
 
     t = time.time()
 
 
     t = time.time()
 
@@ -111,10 +120,14 @@ def log_string(s, remark = ''):
 
     pred_log_t = t
 
 
     pred_log_t = t
 
-    log_file.write('[' + time.ctime() + '] ' + elapsed + ' ' + s + '\n')
+    if t > last_tag_t + 3600:
+        last_tag_t = t
+        print(Fore.RED + time.ctime() + Style.RESET_ALL)
+
+    log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n')
     log_file.flush()
 
     log_file.flush()
 
-    print(Fore.BLUE + '[' + time.ctime() + '] ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
+    print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
 
 ######################################################################
 
 
 ######################################################################
 
@@ -288,6 +301,21 @@ class vignette_logger():
             )
             self.last_t = t
 
             )
             self.last_t = t
 
+def save_examplar_vignettes(data_set, nb, name):
+    n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)
+
+    for k in range(0, nb):
+        b = n[k] // data_set.batch_size
+        m = n[k] % data_set.batch_size
+        i, t = data_set.get_batch(b)
+        i = i[m].float()
+        i.sub_(i.min())
+        i.div_(i.max())
+        if k == 0: patchwork = Tensor(nb, 1, i.size(1), i.size(2))
+        patchwork[k].copy_(i)
+
+    torchvision.utils.save_image(patchwork, name)
+
 ######################################################################
 
 if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
 ######################################################################
 
 if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
@@ -350,6 +378,10 @@ for problem_number in map(int, args.problems.split(',')):
             train_set.nb_samples / (time.time() - t))
         )
 
             train_set.nb_samples / (time.time() - t))
         )
 
+        if args.nb_exemplar_vignettes > 0:
+            save_examplar_vignettes(train_set, args.nb_exemplar_vignettes,
+                                    'examplar_{:d}.png'.format(problem_number))
+
         if args.validation_error_threshold > 0.0:
             validation_set = VignetteSet(problem_number,
                                          args.nb_validation_samples, args.batch_size,
         if args.validation_error_threshold > 0.0:
             validation_set = VignetteSet(problem_number,
                                          args.nb_validation_samples, args.batch_size,