X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=cnn-svrt.py;h=e5ecf768bfdc060b255c6612649f3f2ede62a51a;hb=9eec6d457d017e0204cc80c0e1b24f894d064267;hp=26e7de843bccfd7db4d761369b1ac66e35798319;hpb=e754d1075d8d0a5949e71f426ab07ce73be6099e;p=pysvrt.git diff --git a/cnn-svrt.py b/cnn-svrt.py index 26e7de8..e5ecf76 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -77,11 +77,15 @@ parser.add_argument('--test_loaded_models', type = distutils.util.strtobool, default = 'False', help = 'Should we compute the test errors of loaded models') +parser.add_argument('--problems', + type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23', + help = 'What problems to process') + args = parser.parse_args() ###################################################################### -log_file = open(args.log_file, 'w') +log_file = open(args.log_file, 'a') pred_log_t = None print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL) @@ -250,17 +254,18 @@ def int_to_suffix(n): class vignette_logger(): def __init__(self, delay_min = 60): self.start_t = time.time() + self.last_t = self.start_t self.delay_min = delay_min def __call__(self, n, m): t = time.time() - if t > self.start_t + self.delay_min: + if t > self.last_t + self.delay_min: dt = (t - self.start_t) / m log_string('sample_generation {:d} / {:d}'.format( m, n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']' ) - self.start_t = t + self.last_t = t ###################################################################### @@ -268,6 +273,8 @@ if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_ print('The number of samples must be a multiple of the batch size.') raise +log_string('############### start ###############') + if args.compress_vignettes: log_string('using_compressed_vignettes') VignetteSet = svrtset.CompressedVignetteSet @@ -275,7 +282,7 @@ else: log_string('using_uncompressed_vignettes') VignetteSet = svrtset.VignetteSet -for problem_number in range(1, 24): +for problem_number in map(int, args.problems.split(',')): log_string('############### problem ' + str(problem_number) + ' ###############')