Minor fixes + changed the default number of epochs to 100.
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with selector.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26
27 from colorama import Fore, Back, Style
28
29 import torch
30
31 from torch import optim
32 from torch import FloatTensor as Tensor
33 from torch.autograd import Variable
34 from torch import nn
35 from torch.nn import functional as fn
36 from torchvision import datasets, transforms, utils
37
38 import svrt
39
40 ######################################################################
41
42 parser = argparse.ArgumentParser(
43     description = 'Simple convnet test on the SVRT.',
44     formatter_class = argparse.ArgumentDefaultsHelpFormatter
45 )
46
47 parser.add_argument('--nb_train_samples',
48                     type = int, default = 100000,
49                     help = 'How many samples for train')
50
51 parser.add_argument('--nb_test_samples',
52                     type = int, default = 10000,
53                     help = 'How many samples for test')
54
55 parser.add_argument('--nb_epochs',
56                     type = int, default = 100,
57                     help = 'How many training epochs')
58
59 parser.add_argument('--log_file',
60                     type = str, default = 'cnn-svrt.log',
61                     help = 'Log file name')
62
63 args = parser.parse_args()
64
65 ######################################################################
66
67 log_file = open(args.log_file, 'w')
68
69 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
70
71 def log_string(s):
72     s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + s
73     log_file.write(s + '\n')
74     log_file.flush()
75     print(s)
76
77 ######################################################################
78
79 def generate_set(p, n):
80     target = torch.LongTensor(n).bernoulli_(0.5)
81     t = time.time()
82     input = svrt.generate_vignettes(p, target)
83     t = time.time() - t
84     log_string('DATA_SET_GENERATION {:.02f} sample/s'.format(n / t))
85     input = input.view(input.size(0), 1, input.size(1), input.size(2)).float()
86     return Variable(input), Variable(target)
87
88 ######################################################################
89
90 # Afroze's ShallowNet
91
92 #                       map size   nb. maps
93 #                     ----------------------
94 #    input                128x128    1
95 # -- conv(21x21 x 6)   -> 108x108    6
96 # -- max(2x2)          -> 54x54      6
97 # -- conv(19x19 x 16)  -> 36x36      16
98 # -- max(2x2)          -> 18x18      16
99 # -- conv(18x18 x 120) -> 1x1        120
100 # -- reshape           -> 120        1
101 # -- full(120x84)      -> 84         1
102 # -- full(84x2)        -> 2          1
103
104 class Net(nn.Module):
105     def __init__(self):
106         super(Net, self).__init__()
107         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
108         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
109         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
110         self.fc1 = nn.Linear(120, 84)
111         self.fc2 = nn.Linear(84, 2)
112
113     def forward(self, x):
114         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
115         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
116         x = fn.relu(self.conv3(x))
117         x = x.view(-1, 120)
118         x = fn.relu(self.fc1(x))
119         x = self.fc2(x)
120         return x
121
122 def train_model(train_input, train_target):
123     model, criterion = Net(), nn.CrossEntropyLoss()
124
125     nb_parameters = 0
126     for p in model.parameters():
127         nb_parameters += p.numel()
128     log_string('NB_PARAMETERS {:d}'.format(nb_parameters))
129
130     if torch.cuda.is_available():
131         model.cuda()
132         criterion.cuda()
133
134     optimizer, bs = optim.SGD(model.parameters(), lr = 1e-2), 100
135
136     for k in range(0, args.nb_epochs):
137         acc_loss = 0.0
138         for b in range(0, train_input.size(0), bs):
139             output = model.forward(train_input.narrow(0, b, bs))
140             loss = criterion(output, train_target.narrow(0, b, bs))
141             acc_loss = acc_loss + loss.data[0]
142             model.zero_grad()
143             loss.backward()
144             optimizer.step()
145         log_string('TRAIN_LOSS {:d} {:f}'.format(k, acc_loss))
146
147     return model
148
149 ######################################################################
150
151 def nb_errors(model, data_input, data_target, bs = 100):
152     ne = 0
153
154     for b in range(0, data_input.size(0), bs):
155         output = model.forward(data_input.narrow(0, b, bs))
156         wta_prediction = output.data.max(1)[1].view(-1)
157
158         for i in range(0, bs):
159             if wta_prediction[i] != data_target.narrow(0, b, bs).data[i]:
160                 ne = ne + 1
161
162     return ne
163
164 ######################################################################
165
166 for arg in vars(args):
167     log_string('ARGUMENT ' + str(arg) + ' ' + str(getattr(args, arg)))
168
169 for problem_number in range(1, 24):
170     train_input, train_target = generate_set(problem_number, args.nb_train_samples)
171     test_input, test_target = generate_set(problem_number, args.nb_test_samples)
172
173     if torch.cuda.is_available():
174         train_input, train_target = train_input.cuda(), train_target.cuda()
175         test_input, test_target = test_input.cuda(), test_target.cuda()
176
177     mu, std = train_input.data.mean(), train_input.data.std()
178     train_input.data.sub_(mu).div_(std)
179     test_input.data.sub_(mu).div_(std)
180
181     model = train_model(train_input, train_target)
182
183     nb_train_errors = nb_errors(model, train_input, train_target)
184
185     log_string('TRAIN_ERROR {:d} {:.02f}% {:d} {:d}'.format(
186         problem_number,
187         100 * nb_train_errors / train_input.size(0),
188         nb_train_errors,
189         train_input.size(0))
190     )
191
192     nb_test_errors = nb_errors(model, test_input, test_target)
193
194     log_string('TEST_ERROR {:d} {:.02f}% {:d} {:d}'.format(
195         problem_number,
196         100 * nb_test_errors / test_input.size(0),
197         nb_test_errors,
198         test_input.size(0))
199     )
200
201 ######################################################################