X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pysvrt.git;a=blobdiff_plain;f=cnn-svrt.py;h=ad73f0c84a527ad840be80f76f5e3b850c220b54;hp=35c664fabddf30e1ee1fff7dd694c793cd6ebaf2;hb=eb93f3a7b09a43d6404ca23eaf62eee2e96e59b1;hpb=7aa372bc9cfa44a245b8048eb2216f024ed365e0 diff --git a/cnn-svrt.py b/cnn-svrt.py index 35c664f..ad73f0c 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -65,7 +65,7 @@ args = parser.parse_args() log_file = open(args.log_file, 'w') -print('Logging into ' + args.log_file) +print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL) def log_string(s): s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + \ @@ -112,7 +112,7 @@ def train_model(train_input, train_target): model.cuda() criterion.cuda() - optimizer, bs = optim.SGD(model.parameters(), lr = 1e-2), 100 + optimizer, bs = optim.Adam(model.parameters(), lr = 1e-1), 100 for k in range(0, args.nb_epochs): acc_loss = 0.0