X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=cnn-svrt.py;h=ad73f0c84a527ad840be80f76f5e3b850c220b54;hb=eb93f3a7b09a43d6404ca23eaf62eee2e96e59b1;hp=79d3ff462d6c6d6ea7e2d3114568d2d6ba6c5c09;hpb=44363bdf89bf78a62776129c4a5f97ad6a360293;p=pysvrt.git diff --git a/cnn-svrt.py b/cnn-svrt.py index 79d3ff4..ad73f0c 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -65,7 +65,7 @@ args = parser.parse_args() log_file = open(args.log_file, 'w') -print('Logging into ' + args.log_file) +print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL) def log_string(s): s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + \ @@ -112,7 +112,7 @@ def train_model(train_input, train_target): model.cuda() criterion.cuda() - optimizer, bs = optim.SGD(model.parameters(), lr = 1e-2), 100 + optimizer, bs = optim.Adam(model.parameters(), lr = 1e-1), 100 for k in range(0, args.nb_epochs): acc_loss = 0.0 @@ -144,9 +144,7 @@ def nb_errors(model, data_input, data_target, bs = 100): ###################################################################### -# for problem_number in range(1, 24): - -for problem_number in [ 3 ]: +for problem_number in range(1, 24): train_input, train_target = generate_set(problem_number, args.nb_train_samples) test_input, test_target = generate_set(problem_number, args.nb_test_samples)