2 //////////////////////////////////////////////////////////////////////////////////
3 // This program is free software: you can redistribute it and/or modify //
4 // it under the terms of the version 3 of the GNU General Public License //
5 // as published by the Free Software Foundation. //
7 // This program is distributed in the hope that it will be useful, but //
8 // WITHOUT ANY WARRANTY; without even the implied warranty of //
9 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU //
10 // General Public License for more details. //
12 // You should have received a copy of the GNU General Public License //
13 // along with this program. If not, see <http://www.gnu.org/licenses/>. //
15 // Written by Francois Fleuret //
16 // (C) Ecole Polytechnique Federale de Lausanne //
17 // Contact <francois.fleuret@epfl.ch> for comments & bug reports //
18 //////////////////////////////////////////////////////////////////////////////////
20 // $Id: cmim.cc,v 1.4 2007-08-23 08:36:50 fleuret Exp $
// This software was developed on GNU/Linux systems with many GPL
23 // tools including emacs, gcc, gdb, and bash (see http://www.fsf.org).
./cmim --feature-selection cmim --classifier bayesian --error ber --nb-features 100 --train ./train.dat ./classifier.nb
./cmim --test ./classifier.nb ./test.dat ./result.dat
44 #include "classifier.h"
46 #define BUFFER_SIZE 256
48 FeatureSelector *selector;
49 Classifier *classifier;
50 char classifier_type[BUFFER_SIZE] = "bayesian";
51 char feature_selection_type[BUFFER_SIZE] = "cmim";
52 float reg_param = 0.0;
54 bool balanced_error = false;
55 int nb_selected_features = 100;
57 void check_opt(int argc, char **argv, int n_opt, int n, char *help) {
59 cerr << "Missing argument for " << argv[n_opt] << ".\n";
60 cerr << "Expecting " << help << ".\n";
65 void train(const DataSet &training_set) {
66 timeval tv_start, tv_end;
70 cout << "Selecting features with " << feature_selection_type;
72 gettimeofday(&tv_start, 0);
77 selector = new FeatureSelector(nb_selected_features);
79 if(strcmp(feature_selection_type, "cmim") == 0) selector->cmim(training_set);
80 else if(strcmp(feature_selection_type, "mim") == 0) selector->mim(training_set);
81 else if(strcmp(feature_selection_type, "random") == 0) selector->random(training_set);
83 cerr << "Unknown feature selection type " << feature_selection_type << "\n";
88 gettimeofday(&tv_end, 0);
90 << (float(tv_end.tv_sec - tv_start.tv_sec) * 1000 + float(tv_end.tv_usec - tv_start.tv_usec)/1000)
92 gettimeofday(&tv_start, 0);
93 cout << "Learning with " << classifier_type;
99 DataSet reduced_training_set(training_set, *selector);
101 if(strcmp(classifier_type, "bayesian") == 0) {
102 LinearClassifier *tmp = new LinearClassifier(nb_selected_features);
103 tmp->learn_bayesian(reduced_training_set, balanced_error);
107 else if(strcmp(classifier_type, "perceptron") == 0) {
108 LinearClassifier *tmp = new LinearClassifier(nb_selected_features);
109 tmp->learn_perceptron(reduced_training_set, balanced_error);
114 cerr << "Unknown learning method type " << classifier_type << "\n";
119 gettimeofday(&tv_end, 0);
121 << (float(tv_end.tv_sec - tv_start.tv_sec) * 1000 + float(tv_end.tv_usec - tv_start.tv_usec)/1000)
128 int main(int argc, char **argv) {
129 bool arg_error = false;
132 while(i < argc && !arg_error) {
134 //////////////////////////////////////////////////////////////////////
135 // Parameters ////////////////////////////////////////////////////////
136 //////////////////////////////////////////////////////////////////////
138 if(strcmp(argv[i], "--silent") == 0) {
143 else if(strcmp(argv[i], "--feature-selection") == 0) {
144 check_opt(argc, argv, i, 1, "<random|mim|cmim>");
145 strncpy(feature_selection_type, argv[i+1], BUFFER_SIZE);
149 else if(strcmp(argv[i], "--classifier") == 0) {
150 check_opt(argc, argv, i, 1, "<bayesian|perceptron>");
151 strncpy(classifier_type, argv[i+1], BUFFER_SIZE);
155 else if(strcmp(argv[i], "--error") == 0) {
156 check_opt(argc, argv, i, 1, "<standard|ber>");
157 if(strcmp(argv[i+1], "standard") == 0) balanced_error = false;
158 else if(strcmp(argv[i+1], "ber") == 0) balanced_error = true;
160 cerr << "Unknown error type " << argv[i+1] << "!\n";
166 else if(strcmp(argv[i], "--nb-features") == 0) {
167 check_opt(argc, argv, i, 1, "<int: nb features>");
168 nb_selected_features = atoi(argv[i+1]);
169 if(nb_selected_features <= 0) {
170 cerr << "Unconsistent number of selected features (" << nb_selected_features << ").\n";
176 //////////////////////////////////////////////////////////////////////
177 // Training //////////////////////////////////////////////////////////
178 //////////////////////////////////////////////////////////////////////
180 else if(strcmp(argv[i], "--cross-validation") == 0) {
181 check_opt(argc, argv, i, 3, "<file: data set> <int: nb test samples> <int: nb loops>");
183 cout << "Loading data.\n";
187 ifstream complete_data(argv[i+1]);
188 if(complete_data.fail()) {
189 cerr << "Can not open " << argv[i+1] << " for reading!\n";
193 int nb_for_test = atoi(argv[i+2]);
194 if(nb_for_test <= 0) {
195 cerr << "Unconsistent number of samples for test (" << nb_selected_features << ").\n";
199 int nb_cv_loops = atoi(argv[i+3]);
200 if(nb_cv_loops <= 0) {
201 cerr << "Unconsistent number of cross-validation loops (" << nb_cv_loops << ").\n";
205 DataSet complete_set(complete_data);
207 int n00_test = 0, n01_test = 0, n10_test = 0, n11_test = 0;
208 int n00_train = 0, n01_train = 0, n10_train = 0, n11_train = 0;
210 for(int ncv = 0; ncv < nb_cv_loops; ncv++) {
211 bool for_test[complete_set.nb_samples];
213 for(int s = 0; s < complete_set.nb_samples; s++) for_test[s] = false;
214 for(int i = 0; i < nb_for_test; i++) {
217 s = int(drand48() * complete_set.nb_samples);
218 } while(for_test[s]);
222 DataSet testing_set(complete_set, for_test);
223 for(int s = 0; s < complete_set.nb_samples; s++) for_test[s] = !for_test[s];
224 DataSet training_set(complete_set, for_test);
228 int n00, n01, n10, n11;
231 float result[training_set.nb_samples];
232 compute_error_rates(selector, classifier, training_set, n00, n01, n10, n11, result);
233 n00_train += n00; n01_train += n01; n10_train += n10; n11_train += n11;
237 float result[testing_set.nb_samples];
238 compute_error_rates(selector, classifier, testing_set, n00, n01, n10, n11, result);
239 n00_test += n00; n01_test += n01; n10_test += n10; n11_test += n11;
247 cout << "BER [" << nb_cv_loops << " loops] "
248 << " train " << 0.5 * (float(n01_train)/float(n00_train + n01_train) + float(n10_train)/float(n10_train + n11_train))
249 << " test " << 0.5 * (float(n01_test)/float(n00_test + n01_test) + float(n10_test)/float(n10_test + n11_test)) << "\n";
251 cout << "Error [" << nb_cv_loops << " loops] "
252 << " train " << float(n01_train + n10_train)/float(n00_train + n01_train + n10_train + n11_train)
253 << " test " << float(n01_test + n10_test)/float(n00_test + n01_test + n10_test + n11_test) << "\n";
259 //////////////////////////////////////////////////////////////////////
261 else if(strcmp(argv[i], "--train") == 0) {
262 check_opt(argc, argv, i, 2, "<file: data set> <file: classifier>");
265 cout << "Loading data.\n";
269 ifstream training_data(argv[i+1]);
270 if(training_data.fail()) {
271 cerr << "Can not open " << argv[i+1] << " for reading!\n";
275 DataSet training_set(training_data);
277 //////////////////////////////////////////////////////////////////////
278 // Learning with CMIM + naive Bayesian ///////////////////////////////
279 //////////////////////////////////////////////////////////////////////
283 //////////////////////////////////////////////////////////////////////
284 // Finishing and saving //////////////////////////////////////////////
285 //////////////////////////////////////////////////////////////////////
287 if(verbose) cout << "Saving the classifier in [" << argv[i+2] << "].\n";
288 ofstream classifier_out(argv[i+2]);
289 if(classifier_out.fail()) {
290 cerr << "Can not open " << argv[i+2] << " for writing!\n";
294 selector->save(classifier_out);
295 classifier->save(classifier_out);
303 //////////////////////////////////////////////////////////////////////
304 // Test //////////////////////////////////////////////////////////////
305 //////////////////////////////////////////////////////////////////////
307 else if(strcmp(argv[i], "--test") == 0) {
308 check_opt(argc, argv, i, 3, "<file: classifier> <file: data set> <file: result>");
310 // Load the classifier
312 if(verbose) cout << "Loading the classifier from [" << argv[i+1] << "].\n";
314 ifstream classifier_in(argv[i+1]);
315 if(classifier_in.fail()) {
316 cerr << "Can not open " << argv[i+1] << " for reading!\n";
320 selector = new FeatureSelector(classifier_in);
321 classifier = Classifier::load(classifier_in);
323 // Load the testing data
325 ifstream testing_data(argv[i+2]);
326 if(testing_data.fail()) {
327 cerr << "Can not open " << argv[i+2] << " for reading!\n";
331 ofstream result_out(argv[i+3]);
332 if(result_out.fail()) {
333 cerr << "Can not open " << argv[i+3] << " for writing!\n";
337 DataSet testing_set(testing_data);
339 // Compute the predicted responses
341 int n00, n01, n10, n11;
342 float result[testing_set.nb_samples];
343 compute_error_rates(selector, classifier, testing_set, n00, n01, n10, n11, result);
345 for(int s = 0; s < testing_set.nb_samples; s++)
346 result_out << result[s] << "\n";
348 cout << "ERROR " << float(n01 + n10)/float(n00 + n01 + n10 + n11) << "\n";
349 cout << "BER " << 0.5 * (float(n01)/float(n00 + n01) + float(n10)/float(n10 + n11)) << "\n";
350 cout << "FN " << float(n10)/float(n10+n11) << "\n";
351 cout << "FP " << float(n01)/float(n01+n00) << "\n";
352 cout << "real_0_predicted_0 " << n00 << "\n";
353 cout << "real_0_predicted_1 " << n01 << "\n";
354 cout << "real_1_predicted_0 " << n10 << "\n";
355 cout << "real_1_predicted_1 " << n11 << "\n";
363 else arg_error = true;
367 cerr << "Conditional Mutual Information Maximization\n";
368 cerr << "Written by François Fleuret (c) EPFL 2004\n";
369 cerr << "Comments and bug reports to <francois.fleuret@epfl.ch>\n";
371 cerr << "Usage: " << argv[0] << "\n";
372 cerr << "--silent\n";
373 cerr << "--feature-selection <random|mim|cmim>\n";
374 cerr << "--classifier <bayesian|perceptron>\n";
375 cerr << "--error <standard|ber>\n";
376 cerr << "--nb-features <int: nb of features>\n";
377 cerr << "--cross-validation <file: data set> <int: nb test samples> <int: nb loops>\n";
378 cerr << "--train <file: data set> <file: classifier>\n";
379 cerr << "--test <file: classifier> <file: data set> <file: result>\n";