Replaced SGD with Adam, make the learning rate 1e-1 again.
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with selector.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 from colorama import Fore, Back, Style
27
28 import torch
29
30 from torch import optim
31 from torch import FloatTensor as Tensor
32 from torch.autograd import Variable
33 from torch import nn
34 from torch.nn import functional as fn
35 from torchvision import datasets, transforms, utils
36
37 import svrt
38
39 ######################################################################
40
41 parser = argparse.ArgumentParser(
42     description = 'Simple convnet test on the SVRT.',
43     formatter_class = argparse.ArgumentDefaultsHelpFormatter
44 )
45
46 parser.add_argument('--nb_train_samples',
47                     type = int, default = 100000,
48                     help = 'How many samples for train')
49
50 parser.add_argument('--nb_test_samples',
51                     type = int, default = 10000,
52                     help = 'How many samples for test')
53
54 parser.add_argument('--nb_epochs',
55                     type = int, default = 25,
56                     help = 'How many training epochs')
57
58 parser.add_argument('--log_file',
59                     type = str, default = 'cnn-svrt.log',
60                     help = 'Log file name')
61
62 args = parser.parse_args()
63
64 ######################################################################
65
66 log_file = open(args.log_file, 'w')
67
68 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
69
70 def log_string(s):
71     s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + \
72         str(problem_number) + ' ' + s
73     log_file.write(s + '\n')
74     log_file.flush()
75     print(s)
76
77 ######################################################################
78
79 def generate_set(p, n):
80     target = torch.LongTensor(n).bernoulli_(0.5)
81     t = time.time()
82     input = svrt.generate_vignettes(p, target)
83     t = time.time() - t
84     log_string('DATA_SET_GENERATION {:.02f} sample/s'.format(n / t))
85     input = input.view(input.size(0), 1, input.size(1), input.size(2)).float()
86     return Variable(input), Variable(target)
87
88 ######################################################################
89
90 # 128x128 --conv(9)-> 120x120 --max(4)-> 30x30 --conv(6)-> 25x25 --max(5)-> 5x5
91
92 class Net(nn.Module):
93     def __init__(self):
94         super(Net, self).__init__()
95         self.conv1 = nn.Conv2d(1, 10, kernel_size=9)
96         self.conv2 = nn.Conv2d(10, 20, kernel_size=6)
97         self.fc1 = nn.Linear(500, 100)
98         self.fc2 = nn.Linear(100, 2)
99
100     def forward(self, x):
101         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=4, stride=4))
102         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=5, stride=5))
103         x = x.view(-1, 500)
104         x = fn.relu(self.fc1(x))
105         x = self.fc2(x)
106         return x
107
108 def train_model(train_input, train_target):
109     model, criterion = Net(), nn.CrossEntropyLoss()
110
111     if torch.cuda.is_available():
112         model.cuda()
113         criterion.cuda()
114
115     optimizer, bs = optim.Adam(model.parameters(), lr = 1e-1), 100
116
117     for k in range(0, args.nb_epochs):
118         acc_loss = 0.0
119         for b in range(0, train_input.size(0), bs):
120             output = model.forward(train_input.narrow(0, b, bs))
121             loss = criterion(output, train_target.narrow(0, b, bs))
122             acc_loss = acc_loss + loss.data[0]
123             model.zero_grad()
124             loss.backward()
125             optimizer.step()
126         log_string('TRAIN_LOSS {:d} {:f}'.format(k, acc_loss))
127
128     return model
129
130 ######################################################################
131
132 def nb_errors(model, data_input, data_target, bs = 100):
133     ne = 0
134
135     for b in range(0, data_input.size(0), bs):
136         output = model.forward(data_input.narrow(0, b, bs))
137         wta_prediction = output.data.max(1)[1].view(-1)
138
139         for i in range(0, bs):
140             if wta_prediction[i] != data_target.narrow(0, b, bs).data[i]:
141                 ne = ne + 1
142
143     return ne
144
145 ######################################################################
146
147 for problem_number in range(1, 24):
148     train_input, train_target = generate_set(problem_number, args.nb_train_samples)
149     test_input, test_target = generate_set(problem_number, args.nb_test_samples)
150
151     if torch.cuda.is_available():
152         train_input, train_target = train_input.cuda(), train_target.cuda()
153         test_input, test_target = test_input.cuda(), test_target.cuda()
154
155     mu, std = train_input.data.mean(), train_input.data.std()
156     train_input.data.sub_(mu).div_(std)
157     test_input.data.sub_(mu).div_(std)
158
159     model = train_model(train_input, train_target)
160
161     nb_train_errors = nb_errors(model, train_input, train_target)
162
163     log_string('TRAIN_ERROR {:.02f}% {:d} {:d}'.format(
164         100 * nb_train_errors / train_input.size(0),
165         nb_train_errors,
166         train_input.size(0))
167     )
168
169     nb_test_errors = nb_errors(model, test_input, test_target)
170
171     log_string('TEST_ERROR {:.02f}% {:d} {:d}'.format(
172         100 * nb_test_errors / test_input.size(0),
173         nb_test_errors,
174         test_input.size(0))
175     )
176
177 ######################################################################