Update.
[pysvrt.git] / cnn-svrt.py
index 35c664f..7fe2db2 100755
 #  General Public License for more details.
 #
 #  You should have received a copy of the GNU General Public License
-#  along with selector.  If not, see <http://www.gnu.org/licenses/>.
+#  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
 
 import time
 import argparse
+import math
+import distutils.util
+import re
+
 from colorama import Fore, Back, Style
 
+# Pytorch
+
 import torch
+import torchvision
 
 from torch import optim
+from torch import multiprocessing
 from torch import FloatTensor as Tensor
 from torch.autograd import Variable
 from torch import nn
 from torch.nn import functional as fn
+
 from torchvision import datasets, transforms, utils
 
-import svrt
+# SVRT
+
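+# svrtset provides the VignetteSet / CompressedVignetteSet classes
+# used below to generate and store the problem vignettes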
+import svrtset
 
 ######################################################################
 
 parser = argparse.ArgumentParser(
-    description = 'Simple convnet test on the SVRT.',
+    description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
     formatter_class = argparse.ArgumentDefaultsHelpFormatter
 )
 
 parser.add_argument('--nb_train_samples',
-                    type = int, default = 100000,
-                    help = 'How many samples for train')
+                    type = int, default = 100000)
 
 parser.add_argument('--nb_test_samples',
-                    type = int, default = 10000,
-                    help = 'How many samples for test')
+                    type = int, default = 10000)
+
+parser.add_argument('--nb_validation_samples',
+                    type = int, default = 10000)
+
+parser.add_argument('--validation_error_threshold',
+                    type = float, default = 0.0,
+                    help = 'Early training termination criterion')
 
 parser.add_argument('--nb_epochs',
-                    type = int, default = 25,
-                    help = 'How many training epochs')
+                    type = int, default = 50)
+
+parser.add_argument('--batch_size',
+                    type = int, default = 100)
 
 parser.add_argument('--log_file',
-                    type = str, default = 'cnn-svrt.log',
-                    help = 'Log file name')
+                    type = str, default = 'default.log')
+
+parser.add_argument('--nb_exemplar_vignettes',
+                    type = int, default = -1)
+
+parser.add_argument('--compress_vignettes',
+                    type = distutils.util.strtobool, default = 'True',
+                    help = 'Use lossless compression to reduce the memory footprint')
+
+parser.add_argument('--deep_model',
+                    type = distutils.util.strtobool, default = 'True',
+                    help = 'Use Afroze\'s Alexnet-like deep model')
+
+parser.add_argument('--test_loaded_models',
+                    type = distutils.util.strtobool, default = 'False',
+                    help = 'Should we compute the test errors of loaded models')
+
+parser.add_argument('--problems',
+                    type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
+                    help = 'Which problems to process')
 
 args = parser.parse_args()
 
 ######################################################################
 
-log_file = open(args.log_file, 'w')
+log_file = open(args.log_file, 'a')
+pred_log_t = None
+last_tag_t = time.time()
+
+print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
+
+# Logs and prints the string, with a time stamp and the time elapsed
+# since the previous call. The remark is printed but not logged. A
+# plain time stamp is also printed at most once an hour.
+
+def log_string(s, remark = ''):
+    global pred_log_t, last_tag_t
 
-print('Logging into ' + args.log_file)
+    t = time.time()
+
+    if pred_log_t is None:
+        elapsed = 'start'
+    else:
+        elapsed = '+{:.02f}s'.format(t - pred_log_t)
+
+    pred_log_t = t
 
-def log_string(s):
-    s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + \
-        str(problem_number) + ' ' + s
-    log_file.write(s + '\n')
+    if t > last_tag_t + 3600:
+        last_tag_t = t
+        print(Fore.RED + time.ctime() + Style.RESET_ALL)
+
+    log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n')
     log_file.flush()
-    print(s)
+
+    print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
 
 ######################################################################
 
-def generate_set(p, n):
-    target = torch.LongTensor(n).bernoulli_(0.5)
-    t = time.time()
-    input = svrt.generate_vignettes(p, target)
-    t = time.time() - t
-    log_string('DATA_SET_GENERATION {:.02f} sample/s'.format(n / t))
-    input = input.view(input.size(0), 1, input.size(1), input.size(2)).float()
-    return Variable(input), Variable(target)
+# Afroze's ShallowNet
+
+#                       map size   nb. maps
+#                     ----------------------
+#    input                128x128    1
+# -- conv(21x21 x 6)   -> 108x108    6
+# -- max(2x2)          -> 54x54      6
+# -- conv(19x19 x 16)  -> 36x36      16
+# -- max(2x2)          -> 18x18      16
+# -- conv(18x18 x 120) -> 1x1        120
+# -- reshape           -> 120        1
+# -- full(120x84)      -> 84         1
+# -- full(84x2)        -> 2          1
+
+class AfrozeShallowNet(nn.Module):
+    def __init__(self):
+        super(AfrozeShallowNet, self).__init__()
+        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
+        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
+        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
+        self.fc1 = nn.Linear(120, 84)
+        self.fc2 = nn.Linear(84, 2)
+        self.name = 'shallownet'
+
+    def forward(self, x):
+        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
+        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
+        x = fn.relu(self.conv3(x))
+        x = x.view(-1, 120)
+        x = fn.relu(self.fc1(x))
+        x = self.fc2(x)
+        return x
 
 ######################################################################
 
-# 128x128 --conv(9)-> 120x120 --max(4)-> 30x30 --conv(6)-> 25x25 --max(5)-> 5x5
+# Afroze's DeepNet
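+
+# Map sizes, in the same convention as for the shallow net above:
+
+#                                map size   nb. maps
+#                              ----------------------
+#    input                        128x128    1
+# -- conv(7x7 x 32, stride 4) -> 32x32      32
+# -- max(2x2)                  -> 16x16      32
+# -- conv(5x5 x 96)            -> 16x16      96
+# -- max(2x2)                  -> 8x8        96
+# -- conv(3x3 x 128)           -> 8x8        128
+# -- conv(3x3 x 128)           -> 8x8        128
+# -- conv(3x3 x 96)            -> 8x8        96
+# -- max(2x2)                  -> 4x4        96
+# -- reshape                   -> 1536       1
+# -- full(1536x256)            -> 256        1
+# -- full(256x256)             -> 256        1
+# -- full(256x2)               -> 2          1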
 
-class Net(nn.Module):
+class AfrozeDeepNet(nn.Module):
     def __init__(self):
-        super(Net, self).__init__()
-        self.conv1 = nn.Conv2d(1, 10, kernel_size=9)
-        self.conv2 = nn.Conv2d(10, 20, kernel_size=6)
-        self.fc1 = nn.Linear(500, 100)
-        self.fc2 = nn.Linear(100, 2)
+        super(AfrozeDeepNet, self).__init__()
+        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
+        self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
+        self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
+        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
+        self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
+        self.fc1 = nn.Linear(1536, 256)
+        self.fc2 = nn.Linear(256, 256)
+        self.fc3 = nn.Linear(256, 2)
+        self.name = 'deepnet'
 
     def forward(self, x):
-        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=4, stride=4))
-        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=5, stride=5))
-        x = x.view(-1, 500)
-        x = fn.relu(self.fc1(x))
+        x = self.conv1(x)
+        x = fn.max_pool2d(x, kernel_size=2)
+        x = fn.relu(x)
+
+        x = self.conv2(x)
+        x = fn.max_pool2d(x, kernel_size=2)
+        x = fn.relu(x)
+
+        x = self.conv3(x)
+        x = fn.relu(x)
+
+        x = self.conv4(x)
+        x = fn.relu(x)
+
+        x = self.conv5(x)
+        x = fn.max_pool2d(x, kernel_size=2)
+        x = fn.relu(x)
+
+        x = x.view(-1, 1536)
+
+        x = self.fc1(x)
+        x = fn.relu(x)
+
         x = self.fc2(x)
+        x = fn.relu(x)
+
+        x = self.fc3(x)
+
         return x
 
-def train_model(train_input, train_target):
-    model, criterion = Net(), nn.CrossEntropyLoss()
+######################################################################
+
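+# Returns the number of classification errors of model on data_set,
+# taking the winner-take-all (argmax) response as the prediction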
+def nb_errors(model, data_set):
+    ne = 0
+    for b in range(0, data_set.nb_batches):
+        input, target = data_set.get_batch(b)
+        output = model.forward(Variable(input))
+        wta_prediction = output.data.max(1)[1].view(-1)
+
+        for i in range(0, data_set.batch_size):
+            if wta_prediction[i] != target[i]:
+                ne = ne + 1
+
+    return ne
+
+######################################################################
+
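+# Trains model on train_set with SGD and the cross-entropy loss. If
+# validation_set is not None, the validation error is computed after
+# every epoch and training stops as soon as it is no greater than
+# args.validation_error_threshold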
+def train_model(model, train_set, validation_set):
+    batch_size = args.batch_size
+    criterion = nn.CrossEntropyLoss()
 
     if torch.cuda.is_available():
-        model.cuda()
         criterion.cuda()
 
-    optimizer, bs = optim.SGD(model.parameters(), lr = 1e-2), 100
+    optimizer = optim.SGD(model.parameters(), lr = 1e-2)
 
-    for k in range(0, args.nb_epochs):
+    start_t = time.time()
+
+    for e in range(0, args.nb_epochs):
         acc_loss = 0.0
-        for b in range(0, train_input.size(0), bs):
-            output = model.forward(train_input.narrow(0, b, bs))
-            loss = criterion(output, train_target.narrow(0, b, bs))
+        for b in range(0, train_set.nb_batches):
+            input, target = train_set.get_batch(b)
+            output = model.forward(Variable(input))
+            loss = criterion(output, Variable(target))
             acc_loss = acc_loss + loss.data[0]
             model.zero_grad()
             loss.backward()
             optimizer.step()
-        log_string('TRAIN_LOSS {:d} {:f}'.format(k, acc_loss))
+        dt = (time.time() - start_t) / (e + 1)
+
+        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
+                   ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
+
+        if validation_set is not None:
+            nb_validation_errors = nb_errors(model, validation_set)
+
+            log_string('validation_error {:.02f}% {:d} {:d}'.format(
+                100 * nb_validation_errors / validation_set.nb_samples,
+                nb_validation_errors,
+                validation_set.nb_samples)
+            )
+
+            if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold:
+                log_string('below validation_error_threshold')
+                break
 
     return model
 
 ######################################################################
 
-def nb_errors(model, data_input, data_target, bs = 100):
-    ne = 0
-
-    for b in range(0, data_input.size(0), bs):
-        output = model.forward(data_input.narrow(0, b, bs))
-        wta_prediction = output.data.max(1)[1].view(-1)
+for arg in vars(args):
+    log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
 
-        for i in range(0, bs):
-            if wta_prediction[i] != data_target.narrow(0, b, bs).data[i]:
-                ne = ne + 1
+######################################################################
 
-    return ne
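+# Compact suffix for file names, e.g. 100000 -> '100K', 1000000 -> '1M'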
+def int_to_suffix(n):
+    if n >= 1000000 and n%1000000 == 0:
+        return str(n//1000000) + 'M'
+    elif n >= 1000 and n%1000 == 0:
+        return str(n//1000) + 'K'
+    else:
+        return str(n)
+
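+# Callable handed to the vignette set constructors to log the sample
+# generation progress, with an ETA, at most once every delay_min seconds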
+class vignette_logger():
+    def __init__(self, delay_min = 60):
+        self.start_t = time.time()
+        self.last_t = self.start_t
+        self.delay_min = delay_min
+
+    def __call__(self, n, m):
+        t = time.time()
+        if t > self.last_t + self.delay_min:
+            dt = (t - self.start_t) / m
+            log_string('sample_generation {:d} / {:d}'.format(
+                m,
+                n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
+            )
+            self.last_t = t
+
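+# Saves nb vignettes picked at random in data_set as a single image,
+# each vignette normalized to [0, 1]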
+def save_exemplar_vignettes(data_set, nb, name):
+    n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)
+
+    for k in range(0, nb):
+        b = n[k] // data_set.batch_size
+        m = n[k] % data_set.batch_size
+        i, t = data_set.get_batch(b)
+        i = i[m].float()
+        i.sub_(i.min())
+        i.div_(i.max())
+        if k == 0: patchwork = Tensor(nb, 1, i.size(1), i.size(2))
+        patchwork[k].copy_(i)
+
+    torchvision.utils.save_image(patchwork, name)
 
 ######################################################################
 
-for problem_number in range(1, 24):
-    train_input, train_target = generate_set(problem_number, args.nb_train_samples)
-    test_input, test_target = generate_set(problem_number, args.nb_test_samples)
+if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
+    raise ValueError('The number of samples must be a multiple of the batch size.')
 
-    if torch.cuda.is_available():
-        train_input, train_target = train_input.cuda(), train_target.cuda()
-        test_input, test_target = test_input.cuda(), test_target.cuda()
+log_string('############### start ###############')
+
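+# Both vignette set classes have the same interface (nb_samples,
+# nb_batches, get_batch); the compressed one keeps the vignettes
+# losslessly compressed in memory to reduce the footprint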
+if args.compress_vignettes:
+    log_string('using_compressed_vignettes')
+    VignetteSet = svrtset.CompressedVignetteSet
+else:
+    log_string('using_uncompressed_vignettes')
+    VignetteSet = svrtset.VignetteSet
+
+for problem_number in map(int, args.problems.split(',')):
+
+    log_string('############### problem ' + str(problem_number) + ' ###############')
+
+    if args.deep_model:
+        model = AfrozeDeepNet()
+    else:
+        model = AfrozeShallowNet()
+
+    if torch.cuda.is_available(): model.cuda()
+
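+    # e.g. 'deepnet_pb:1_ns:100K.param' with the default arguments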
+    model_filename = model.name + '_pb:' + \
+                     str(problem_number) + '_ns:' + \
+                     int_to_suffix(args.nb_train_samples) + '.param'
+
+    nb_parameters = 0
+    for p in model.parameters(): nb_parameters += p.numel()
+    log_string('nb_parameters {:d}'.format(nb_parameters))
+
+    ##################################################
+    # Tries to load the model
+
+    need_to_train = False
+    try:
+        model.load_state_dict(torch.load(model_filename))
+        log_string('loaded_model ' + model_filename)
+    except:
+        need_to_train = True
+
+    ##################################################
+    # Train if necessary
+
+    if need_to_train:
+
+        log_string('training_model ' + model_filename)
+
+        t = time.time()
+
+        train_set = VignetteSet(problem_number,
+                                args.nb_train_samples, args.batch_size,
+                                cuda = torch.cuda.is_available(),
+                                logger = vignette_logger())
+
+        log_string('data_generation {:0.2f} samples / s'.format(
+            train_set.nb_samples / (time.time() - t))
+        )
+
+        if args.nb_exemplar_vignettes > 0:
+            save_exemplar_vignettes(train_set, args.nb_exemplar_vignettes,
+                                    'exemplar_{:d}.png'.format(problem_number))
+
+        if args.validation_error_threshold > 0.0:
+            validation_set = VignetteSet(problem_number,
+                                         args.nb_validation_samples, args.batch_size,
+                                         cuda = torch.cuda.is_available(),
+                                         logger = vignette_logger())
+        else:
+            validation_set = None
+
+        train_model(model, train_set, validation_set)
+        torch.save(model.state_dict(), model_filename)
+        log_string('saved_model ' + model_filename)
+
+        nb_train_errors = nb_errors(model, train_set)
+
+        log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
+            problem_number,
+            100 * nb_train_errors / train_set.nb_samples,
+            nb_train_errors,
+            train_set.nb_samples)
+        )
 
-    mu, std = train_input.data.mean(), train_input.data.std()
-    train_input.data.sub_(mu).div_(std)
-    test_input.data.sub_(mu).div_(std)
+    ##################################################
+    # Test if necessary
 
-    model = train_model(train_input, train_target)
+    if need_to_train or args.test_loaded_models:
 
-    nb_train_errors = nb_errors(model, train_input, train_target)
+        t = time.time()
 
-    log_string('TRAIN_ERROR {:.02f}% {:d} {:d}'.format(
-        100 * nb_train_errors / train_input.size(0),
-        nb_train_errors,
-        train_input.size(0))
-    )
+        test_set = VignetteSet(problem_number,
+                               args.nb_test_samples, args.batch_size,
+                               cuda = torch.cuda.is_available())
 
 
-    nb_test_errors = nb_errors(model, test_input, test_target)
+        nb_test_errors = nb_errors(model, test_set)
 
-    log_string('TEST_ERROR {:.02f}% {:d} {:d}'.format(
-        100 * nb_test_errors / test_input.size(0),
-        nb_test_errors,
-        test_input.size(0))
-    )
+        log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
+            problem_number,
+            100 * nb_test_errors / test_set.nb_samples,
+            nb_test_errors,
+            test_set.nb_samples)
+        )
 
 ######################################################################