Update.
authorFrancois Fleuret <francois@fleuret.org>
Thu, 15 Jun 2017 12:26:54 +0000 (14:26 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Thu, 15 Jun 2017 12:26:54 +0000 (14:26 +0200)
cnn-svrt.py

index 755d1c7..79d3ff4 100755 (executable)
@@ -23,6 +23,7 @@
 
 import time
 import argparse
+from colorama import Fore, Back, Style
 
 import torch
 
@@ -54,14 +55,21 @@ parser.add_argument('--nb_epochs',
                     type = int, default = 25,
                     help = 'How many training epochs')
 
+parser.add_argument('--log_file',
+                    type = str, default = 'cnn-svrt.log',
+                    help = 'Log file name')
+
 args = parser.parse_args()
 
 ######################################################################
 
-log_file = open('cnn-svrt.log', 'w')
+log_file = open(args.log_file, 'w')
+
+print('Logging into ' + args.log_file)
 
 def log_string(s):
-    s = time.ctime() + ' ' + str(problem_number) + ' | ' + s
+    s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + \
+        str(problem_number) + ' ' + s
     log_file.write(s + '\n')
     log_file.flush()
     print(s)
@@ -70,7 +78,10 @@ def log_string(s):
 
 def generate_set(p, n):
     target = torch.LongTensor(n).bernoulli_(0.5)
+    t = time.time()
     input = svrt.generate_vignettes(p, target)
+    t = time.time() - t
+    log_string('DATA_SET_GENERATION {:.02f} sample/s'.format(n / t))
     input = input.view(input.size(0), 1, input.size(1), input.size(2)).float()
     return Variable(input), Variable(target)
 
@@ -101,7 +112,7 @@ def train_model(train_input, train_target):
         model.cuda()
         criterion.cuda()
 
-    optimizer, bs = optim.SGD(model.parameters(), lr = 1e-1), 100
+    optimizer, bs = optim.SGD(model.parameters(), lr = 1e-2), 100
 
     for k in range(0, args.nb_epochs):
         acc_loss = 0.0
@@ -133,7 +144,9 @@ def nb_errors(model, data_input, data_target, bs = 100):
 
 ######################################################################
 
-for problem_number in range(1, 24):
+# for problem_number in range(1, 24):
+
+for problem_number in [ 3 ]:
     train_input, train_target = generate_set(problem_number, args.nb_train_samples)
     test_input, test_target = generate_set(problem_number, args.nb_test_samples)
 
@@ -147,9 +160,17 @@ for problem_number in range(1, 24):
 
     model = train_model(train_input, train_target)
 
+    nb_train_errors = nb_errors(model, train_input, train_target)
+
+    log_string('TRAIN_ERROR {:.02f}% {:d} {:d}'.format(
+        100 * nb_train_errors / train_input.size(0),
+        nb_train_errors,
+        train_input.size(0))
+    )
+
     nb_test_errors = nb_errors(model, test_input, test_target)
 
-    log_string('TEST_ERROR {:.02f}% ({:d}/{:d})'.format(
+    log_string('TEST_ERROR {:.02f}% {:d} {:d}'.format(
         100 * nb_test_errors / test_input.size(0),
         nb_test_errors,
         test_input.size(0))