X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=cnn-svrt.py;h=0c63b85f184470bfed0e9d8daf7bb1e395234bda;hb=24789240ca5395a857c16e602a2d0f5e8cb176d8;hp=7dc6dfff1d6ee98df36ab92624c4ad901f3acc29;hpb=9ee9a775ccd2391990b3ab226e73c86bd19bd36a;p=pysvrt.git diff --git a/cnn-svrt.py b/cnn-svrt.py index 7dc6dff..0c63b85 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -29,7 +29,10 @@ import distutils.util import re import signal -from colorama import Fore, Back, Style +try: + from colorama import Fore, Back, Style +except ImportError: + Fore, Back, Style = '', '', '' # Pytorch @@ -540,7 +543,10 @@ for problem_number in map(int, args.problems.split(',')): else: validation_set = None - train_model(model, model_filename, train_set, validation_set, nb_epochs_done = nb_epochs_done) + train_model(model, model_filename, + train_set, validation_set, + nb_epochs_done = nb_epochs_done) + log_string('saved_model ' + model_filename) nb_train_errors = nb_errors(model, train_set)