X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=cnn-svrt.py;h=ad73f0c84a527ad840be80f76f5e3b850c220b54;hb=eb93f3a7b09a43d6404ca23eaf62eee2e96e59b1;hp=d5685f426f518233d710598ea5f1ece4c1e7ce68;hpb=664435944d9750efb805d9a2035f1d4f4c238a25;p=pysvrt.git diff --git a/cnn-svrt.py b/cnn-svrt.py index d5685f4..ad73f0c 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -1,6 +1,29 @@ -#!/usr/bin/env python-for-pytorch +#!/usr/bin/env python + +# svrt is the ``Synthetic Visual Reasoning Test'', an image +# generator for evaluating classification performance of machine +# learning systems, humans and primates. +# +# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ +# Written by Francois Fleuret +# +# This file is part of svrt. +# +# svrt is free software: you can redistribute it and/or modify it +# under the terms of the GNU General Public License version 3 as +# published by the Free Software Foundation. +# +# svrt is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with selector. If not, see . import time +import argparse +from colorama import Fore, Back, Style import torch @@ -11,14 +34,54 @@ from torch import nn from torch.nn import functional as fn from torchvision import datasets, transforms, utils -from _ext import svrt +import svrt + +###################################################################### + +parser = argparse.ArgumentParser( + description = 'Simple convnet test on the SVRT.', + formatter_class = argparse.ArgumentDefaultsHelpFormatter +) + +parser.add_argument('--nb_train_samples', + type = int, default = 100000, + help = 'How many samples for train') + +parser.add_argument('--nb_test_samples', + type = int, default = 10000, + help = 'How many samples for test') + +parser.add_argument('--nb_epochs', + type = int, default = 25, + help = 'How many training epochs') + +parser.add_argument('--log_file', + type = str, default = 'cnn-svrt.log', + help = 'Log file name') + +args = parser.parse_args() + +###################################################################### + +log_file = open(args.log_file, 'w') + +print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL) + +def log_string(s): + s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + \ + str(problem_number) + ' ' + s + log_file.write(s + '\n') + log_file.flush() + print(s) ###################################################################### -# The data def generate_set(p, n): target = torch.LongTensor(n).bernoulli_(0.5) + t = time.time() input = svrt.generate_vignettes(p, target) + t = time.time() - t + log_string('DATA_SET_GENERATION {:.02f} sample/s'.format(n / t)) input = input.view(input.size(0), 1, input.size(1), input.size(2)).float() return Variable(input), Variable(target) @@ -49,50 +112,42 @@ def train_model(train_input, train_target): model.cuda() criterion.cuda() - nb_epochs = 25 - optimizer, bs = optim.SGD(model.parameters(), lr = 1e-1), 100 + optimizer, bs = optim.Adam(model.parameters(), lr = 1e-1), 100 - for k in range(0, nb_epochs): - for b in range(0, nb_train_samples, bs): + for k in range(0, args.nb_epochs): + acc_loss = 0.0 + for b in range(0, train_input.size(0), bs): output = model.forward(train_input.narrow(0, b, bs)) loss = criterion(output, train_target.narrow(0, b, bs)) + acc_loss = acc_loss + loss.data[0] model.zero_grad() loss.backward() optimizer.step() + log_string('TRAIN_LOSS {:d} {:f}'.format(k, acc_loss)) return model ###################################################################### -def print_test_error(model, test_input, test_target): - bs = 100 - nb_test_errors = 0 +def nb_errors(model, data_input, data_target, bs = 100): + ne = 0 - for b in range(0, nb_test_samples, bs): - output = model.forward(test_input.narrow(0, b, bs)) - _, wta = torch.max(output.data, 1) + for b in range(0, data_input.size(0), bs): + output = model.forward(data_input.narrow(0, b, bs)) + wta_prediction = output.data.max(1)[1].view(-1) for i in range(0, bs): - if wta[i][0] != test_target.narrow(0, b, bs).data[i]: - nb_test_errors = nb_test_errors + 1 + if wta_prediction[i] != data_target.narrow(0, b, bs).data[i]: + ne = ne + 1 - print('TEST_ERROR {:.02f}% ({:d}/{:d})'.format( - 100 * nb_test_errors / nb_test_samples, - nb_test_errors, - nb_test_samples) - ) + return ne ###################################################################### -nb_train_samples = 100000 -nb_test_samples = 10000 +for problem_number in range(1, 24): + train_input, train_target = generate_set(problem_number, args.nb_train_samples) + test_input, test_target = generate_set(problem_number, args.nb_test_samples) -for p in range(1, 24): - print('-- PROBLEM #{:d} --'.format(p)) - - t1 = time.time() - train_input, train_target = generate_set(p, nb_train_samples) - test_input, test_target = generate_set(p, nb_test_samples) if torch.cuda.is_available(): train_input, train_target = train_input.cuda(), train_target.cuda() test_input, test_target = test_input.cuda(), test_target.cuda() @@ -101,17 +156,22 @@ for p in range(1, 24): train_input.data.sub_(mu).div_(std) test_input.data.sub_(mu).div_(std) - t2 = time.time() - print('[data generation {:.02f}s]'.format(t2 - t1)) model = train_model(train_input, train_target) - t3 = time.time() - print('[train {:.02f}s]'.format(t3 - t2)) - print_test_error(model, test_input, test_target) + nb_train_errors = nb_errors(model, train_input, train_target) + + log_string('TRAIN_ERROR {:.02f}% {:d} {:d}'.format( + 100 * nb_train_errors / train_input.size(0), + nb_train_errors, + train_input.size(0)) + ) - t4 = time.time() + nb_test_errors = nb_errors(model, test_input, test_target) - print('[test {:.02f}s]'.format(t4 - t3)) - print() + log_string('TEST_ERROR {:.02f}% {:d} {:d}'.format( + 100 * nb_test_errors / test_input.size(0), + nb_test_errors, + test_input.size(0)) + ) ######################################################################