Update.
[pysvrt.git] / cnn-svrt.py
index f3d350e..d6c7169 100755 (executable)
@@ -25,12 +25,14 @@ import time
 import argparse
 import math
 import distutils.util
+import re
 
 from colorama import Fore, Back, Style
 
 # Pytorch
 
 import torch
+import torchvision
 
 from torch import optim
 from torch import FloatTensor as Tensor
@@ -72,6 +74,9 @@ parser.add_argument('--batch_size',
 parser.add_argument('--log_file',
                     type = str, default = 'default.log')
 
+parser.add_argument('--nb_exemplar_vignettes',
+                    type = int, default = -1)
+
 parser.add_argument('--compress_vignettes',
                     type = distutils.util.strtobool, default = 'True',
                     help = 'Use lossless compression to reduce the memory footprint')
@@ -94,13 +99,15 @@ args = parser.parse_args()
 
 log_file = open(args.log_file, 'a')
 pred_log_t = None
+last_tag_t = time.time()
 
 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
 
 # Log and prints the string, with a time stamp. Does not log the
 # remark
+
 def log_string(s, remark = ''):
-    global pred_log_t
+    global pred_log_t, last_tag_t
 
     t = time.time()
 
@@ -111,10 +118,14 @@ def log_string(s, remark = ''):
 
     pred_log_t = t
 
-    log_file.write('[' + time.ctime() + '] ' + elapsed + ' ' + s + '\n')
+    if t > last_tag_t + 3600:
+        last_tag_t = t
+        print(Fore.RED + time.ctime() + Style.RESET_ALL)
+
+    log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n')
     log_file.flush()
 
-    print(Fore.BLUE + '[' + time.ctime() + '] ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
+    print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
 
 ######################################################################
 
@@ -288,6 +299,21 @@ class vignette_logger():
             )
             self.last_t = t
 
+def save_examplar_vignettes(data_set, nb, name):
+    n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)
+
+    for k in range(0, nb):
+        b = n[k] // data_set.batch_size
+        m = n[k] % data_set.batch_size
+        i, t = data_set.get_batch(b)
+        i = i[m].float()
+        i.sub_(i.min())
+        i.div_(i.max())
+        if k == 0: patchwork = Tensor(nb, 1, i.size(1), i.size(2))
+        patchwork[k].copy_(i)
+
+    torchvision.utils.save_image(patchwork, name)
+
 ######################################################################
 
 if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
@@ -350,6 +376,10 @@ for problem_number in map(int, args.problems.split(',')):
             train_set.nb_samples / (time.time() - t))
         )
 
+        if args.nb_exemplar_vignettes > 0:
+            save_examplar_vignettes(train_set, args.nb_exemplar_vignettes,
+                                    'examplar_{:d}.png'.format(problem_number))
+
         if args.validation_error_threshold > 0.0:
             validation_set = VignetteSet(problem_number,
                                          args.nb_validation_samples, args.batch_size,