Minor fixes + changed the default number of epochs to 100.
[pysvrt.git] / cnn-svrt.py
index c1fe3ac..96fb498 100755 (executable)
@@ -22,6 +22,9 @@
 #  along with selector.  If not, see <http://www.gnu.org/licenses/>.
 
 import time
+import argparse
+
+from colorama import Fore, Back, Style
 
 import torch
 
@@ -36,28 +39,82 @@ import svrt
 
 ######################################################################
 
+parser = argparse.ArgumentParser(
+    description = 'Simple convnet test on the SVRT.',
+    formatter_class = argparse.ArgumentDefaultsHelpFormatter
+)
+
+parser.add_argument('--nb_train_samples',
+                    type = int, default = 100000,
+                    help = 'How many samples for train')
+
+parser.add_argument('--nb_test_samples',
+                    type = int, default = 10000,
+                    help = 'How many samples for test')
+
+parser.add_argument('--nb_epochs',
+                    type = int, default = 100,
+                    help = 'How many training epochs')
+
+parser.add_argument('--log_file',
+                    type = str, default = 'cnn-svrt.log',
+                    help = 'Log file name')
+
+args = parser.parse_args()
+
+######################################################################
+
+log_file = open(args.log_file, 'w')
+
+print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
+
+def log_string(s):
+    s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + s
+    log_file.write(s + '\n')
+    log_file.flush()
+    print(s)
+
+######################################################################
+
 def generate_set(p, n):
     target = torch.LongTensor(n).bernoulli_(0.5)
+    t = time.time()
     input = svrt.generate_vignettes(p, target)
+    t = time.time() - t
+    log_string('DATA_SET_GENERATION {:.02f} sample/s'.format(n / t))
     input = input.view(input.size(0), 1, input.size(1), input.size(2)).float()
     return Variable(input), Variable(target)
 
 ######################################################################
 
-# 128x128 --conv(9)-> 120x120 --max(4)-> 30x30 --conv(6)-> 25x25 --max(5)-> 5x5
+# Afroze's ShallowNet
+
+#                       map size   nb. maps
+#                     ----------------------
+#    input                128x128    1
+# -- conv(21x21 x 6)   -> 108x108    6
+# -- max(2x2)          -> 54x54      6
+# -- conv(19x19 x 16)  -> 36x36      16
+# -- max(2x2)          -> 18x18      16
+# -- conv(18x18 x 120) -> 1x1        120
+# -- reshape           -> 120        1
+# -- full(120x84)      -> 84         1
+# -- full(84x2)        -> 2          1
 
 class Net(nn.Module):
     def __init__(self):
         super(Net, self).__init__()
-        self.conv1 = nn.Conv2d(1, 10, kernel_size=9)
-        self.conv2 = nn.Conv2d(10, 20, kernel_size=6)
-        self.fc1 = nn.Linear(500, 100)
-        self.fc2 = nn.Linear(100, 2)
+        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
+        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
+        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
+        self.fc1 = nn.Linear(120, 84)
+        self.fc2 = nn.Linear(84, 2)
 
     def forward(self, x):
-        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=4, stride=4))
-        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=5, stride=5))
-        x = x.view(-1, 500)
+        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
+        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
+        x = fn.relu(self.conv3(x))
+        x = x.view(-1, 120)
         x = fn.relu(self.fc1(x))
         x = self.fc2(x)
         return x
@@ -65,54 +122,53 @@ class Net(nn.Module):
 def train_model(train_input, train_target):
     model, criterion = Net(), nn.CrossEntropyLoss()
 
+    nb_parameters = 0
+    for p in model.parameters():
+        nb_parameters += p.numel()
+    log_string('NB_PARAMETERS {:d}'.format(nb_parameters))
+
     if torch.cuda.is_available():
         model.cuda()
         criterion.cuda()
 
-    nb_epochs = 25
-    optimizer, bs = optim.SGD(model.parameters(), lr = 1e-1), 100
+    optimizer, bs = optim.SGD(model.parameters(), lr = 1e-2), 100
 
-    for k in range(0, nb_epochs):
-        for b in range(0, nb_train_samples, bs):
+    for k in range(0, args.nb_epochs):
+        acc_loss = 0.0
+        for b in range(0, train_input.size(0), bs):
             output = model.forward(train_input.narrow(0, b, bs))
             loss = criterion(output, train_target.narrow(0, b, bs))
+            acc_loss = acc_loss + loss.data[0]
             model.zero_grad()
             loss.backward()
             optimizer.step()
+        log_string('TRAIN_LOSS {:d} {:f}'.format(k, acc_loss))
 
     return model
 
 ######################################################################
 
-def print_test_error(model, test_input, test_target):
-    bs = 100
-    nb_test_errors = 0
+def nb_errors(model, data_input, data_target, bs = 100):
+    ne = 0
 
-    for b in range(0, nb_test_samples, bs):
-        output = model.forward(test_input.narrow(0, b, bs))
+    for b in range(0, data_input.size(0), bs):
+        output = model.forward(data_input.narrow(0, b, bs))
         wta_prediction = output.data.max(1)[1].view(-1)
 
         for i in range(0, bs):
-            if wta_prediction[i] != test_target.narrow(0, b, bs).data[i]:
-                nb_test_errors = nb_test_errors + 1
+            if wta_prediction[i] != data_target.narrow(0, b, bs).data[i]:
+                ne = ne + 1
 
-    print('TEST_ERROR {:.02f}% ({:d}/{:d})'.format(
-        100 * nb_test_errors / nb_test_samples,
-        nb_test_errors,
-        nb_test_samples)
-    )
+    return ne
 
 ######################################################################
 
-nb_train_samples = 100000
-nb_test_samples = 10000
+for arg in vars(args):
+    log_string('ARGUMENT ' + str(arg) + ' ' + str(getattr(args, arg)))
 
-for p in range(1, 24):
-    print('-- PROBLEM #{:d} --'.format(p))
-
-    t1 = time.time()
-    train_input, train_target = generate_set(p, nb_train_samples)
-    test_input, test_target = generate_set(p, nb_test_samples)
+for problem_number in range(1, 24):
+    train_input, train_target = generate_set(problem_number, args.nb_train_samples)
+    test_input, test_target = generate_set(problem_number, args.nb_test_samples)
 
     if torch.cuda.is_available():
         train_input, train_target = train_input.cuda(), train_target.cuda()
@@ -122,17 +178,24 @@ for p in range(1, 24):
     train_input.data.sub_(mu).div_(std)
     test_input.data.sub_(mu).div_(std)
 
-    t2 = time.time()
-    print('[data generation {:.02f}s]'.format(t2 - t1))
     model = train_model(train_input, train_target)
 
-    t3 = time.time()
-    print('[train {:.02f}s]'.format(t3 - t2))
-    print_test_error(model, test_input, test_target)
+    nb_train_errors = nb_errors(model, train_input, train_target)
+
+    log_string('TRAIN_ERROR {:d} {:.02f}% {:d} {:d}'.format(
+        problem_number,
+        100 * nb_train_errors / train_input.size(0),
+        nb_train_errors,
+        train_input.size(0))
+    )
 
-    t4 = time.time()
+    nb_test_errors = nb_errors(model, test_input, test_target)
 
-    print('[test {:.02f}s]'.format(t4 - t3))
-    print()
+    log_string('TEST_ERROR {:d} {:.02f}% {:d} {:d}'.format(
+        problem_number,
+        100 * nb_test_errors / test_input.size(0),
+        nb_test_errors,
+        test_input.size(0))
+    )
 
 ######################################################################