3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
# along with svrt. If not, see <http://www.gnu.org/licenses/>.
import argparse
import os
import time

import torch
from colorama import Fore, Back, Style
from torch import FloatTensor as Tensor
from torch import nn
from torch import optim
from torch.autograd import Variable
from torch.nn import functional as fn
from torchvision import datasets, transforms, utils

from vignette_set import VignetteSet, CompressedVignetteSet
41 ######################################################################
# Command-line interface. Defaults are shown in --help thanks to
# ArgumentDefaultsHelpFormatter.
parser = argparse.ArgumentParser(
    description = 'Simple convnet test on the SVRT.',
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--nb_train_batches',
                    type = int, default = 1000,
                    help = 'How many samples for train')

parser.add_argument('--nb_test_batches',
                    type = int, default = 100,
                    help = 'How many samples for test')

parser.add_argument('--nb_epochs',
                    type = int, default = 50,
                    help = 'How many training epochs')

parser.add_argument('--batch_size',
                    type = int, default = 100,
                    help = 'Mini-batch size')

parser.add_argument('--log_file',
                    type = str, default = 'cnn-svrt.log',
                    help = 'Log file name')

parser.add_argument('--compress_vignettes',
                    action='store_true', default = False,
                    help = 'Use lossless compression to reduce the memory footprint')

args = parser.parse_args()
74 ######################################################################
# Open the log file once for the whole run; log_string below appends to it.
log_file = open(args.log_file, 'w')

print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)

def log_string(s):
    """Prefix s with a (colored) timestamp, write it to the log file and echo it.

    NOTE(review): the flush + echo lines were reconstructed from the
    surrounding listing gaps — confirm against the original file.
    """
    s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + s
    log_file.write(s + '\n')
    log_file.flush()
    print(s)
86 ######################################################################
91 # ----------------------
93 # -- conv(21x21 x 6) -> 108x108 6
94 # -- max(2x2) -> 54x54 6
95 # -- conv(19x19 x 16) -> 36x36 16
96 # -- max(2x2) -> 18x18 16
97 # -- conv(18x18 x 120) -> 1x1 120
99 # -- full(120x84) -> 84 1
100 # -- full(84x2) -> 2 1
class AfrozeShallowNet(nn.Module):
    """LeNet-style shallow convnet for 1x128x128 SVRT vignettes.

    Feature sizes (see the architecture comment above):
    128x128 -conv21-> 108x108x6 -max2-> 54x54x6 -conv19-> 36x36x16
    -max2-> 18x18x16 -conv18-> 1x1x120 -fc-> 84 -fc-> 2 logits.
    """

    def __init__(self):
        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)
        # Used by the main loop to build the saved-parameter filename.
        self.name = 'shallownet'

    def forward(self, x):
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        # conv3 leaves a 1x1 spatial map; flatten to (batch, 120) for the FC head.
        x = x.view(-1, 120)
        x = fn.relu(self.fc1(x))
        x = self.fc2(x)
        return x
121 ######################################################################
def train_model(model, train_set):
    """Train model on train_set for args.nb_epochs epochs with plain SGD.

    Logs the accumulated cross-entropy loss once per epoch through
    log_string and returns the trained model.
    """
    batch_size = args.batch_size
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        criterion.cuda()

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    for e in range(0, args.nb_epochs):
        acc_loss = 0.0
        for b in range(0, train_set.nb_batches):
            input, target = train_set.get_batch(b)
            output = model.forward(Variable(input))
            loss = criterion(output, Variable(target))
            # loss.data[0] is the legacy (pre-0.4) scalar access; kept for
            # consistency with the Variable API used throughout this file.
            acc_loss = acc_loss + loss.data[0]
            model.zero_grad()
            loss.backward()
            optimizer.step()
        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss))

    return model
146 ######################################################################
def nb_errors(model, data_set):
    """Return the number of winner-take-all misclassifications on data_set.

    Predictions are the argmax over the 2 output logits; every sample whose
    argmax differs from its target counts as one error.
    """
    nb_errors = 0
    for b in range(0, data_set.nb_batches):
        input, target = data_set.get_batch(b)
        output = model.forward(Variable(input))
        # max(1) returns (values, indices); index 1 is the per-row argmax.
        wta_prediction = output.data.max(1)[1].view(-1)

        for i in range(0, data_set.batch_size):
            if wta_prediction[i] != target[i]:
                nb_errors = nb_errors + 1

    return nb_errors
161 ######################################################################
# Record the full configuration in the log for reproducibility.
for arg in vars(args):
    log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))

# One model per SVRT problem (problems are numbered 1..23).
for problem_number in range(1, 24):
    if args.compress_vignettes:
        train_set = CompressedVignetteSet(problem_number, args.nb_train_batches, args.batch_size)
        test_set = CompressedVignetteSet(problem_number, args.nb_test_batches, args.batch_size)
    else:
        train_set = VignetteSet(problem_number, args.nb_train_batches, args.batch_size)
        test_set = VignetteSet(problem_number, args.nb_test_batches, args.batch_size)

    model = AfrozeShallowNet()

    if torch.cuda.is_available():
        model.cuda()

    nb_parameters = 0
    for p in model.parameters():
        nb_parameters += p.numel()
    log_string('nb_parameters {:d}'.format(nb_parameters))

    model_filename = model.name + '_' + str(problem_number) + '_' + str(train_set.nb_batches) + '.param'

    # Reuse previously trained weights when available, otherwise train and save.
    if os.path.isfile(model_filename):
        model.load_state_dict(torch.load(model_filename))
        log_string('loaded_model ' + model_filename)
    else:
        log_string('training_model')
        train_model(model, train_set)
        torch.save(model.state_dict(), model_filename)
        log_string('saved_model ' + model_filename)

    nb_train_errors = nb_errors(model, train_set)

    log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
        problem_number,
        100 * nb_train_errors / train_set.nb_samples,
        nb_train_errors,
        train_set.nb_samples)
    )

    nb_test_errors = nb_errors(model, test_set)

    log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
        problem_number,
        100 * nb_test_errors / test_set.nb_samples,
        nb_test_errors,
        test_set.nb_samples)
    )
213 ######################################################################