Automatic commit
[svrt.git] / naive_bayesian_classifier.cc
1 /*
2  *  svrt is the ``Synthetic Visual Reasoning Test'', an image
3  *  generator for evaluating classification performance of machine
4  *  learning systems, humans and primates.
5  *
6  *  Copyright (c) 2009 Idiap Research Institute, http://www.idiap.ch/
7  *  Written by Francois Fleuret <francois.fleuret@idiap.ch>
8  *
9  *  This file is part of svrt.
10  *
11  *  svrt is free software: you can redistribute it and/or modify it
12  *  under the terms of the GNU General Public License version 3 as
13  *  published by the Free Software Foundation.
14  *
15  *  svrt is distributed in the hope that it will be useful, but
16  *  WITHOUT ANY WARRANTY; without even the implied warranty of
17  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  *  General Public License for more details.
19  *
20  *  You should have received a copy of the GNU General Public License
21  *  along with selector.  If not, see <http://www.gnu.org/licenses/>.
22  *
23  */
24
25 #include "naive_bayesian_classifier.h"
26 #include "classifier_reader.h"
27
28 NaiveBayesianClassifier::NaiveBayesianClassifier() { }
29
30 NaiveBayesianClassifier::~NaiveBayesianClassifier() { }
31
32 const char *NaiveBayesianClassifier::name() {
33   return "NAIVE_BAYESIAN";
34 }
35
36 void NaiveBayesianClassifier::train(int nb_vignettes, Vignette *vignettes, int *labels) {
37   for(int k = 0; k < Vignette::width * Vignette::height * Vignette::nb_grayscales; k++) {
38     proba_given_0[k] = 0;
39     proba_given_1[k] = 0;
40   }
41
42   int nb_0 = 0, nb_1 = 0;
43
44   global.bar.init(&cout, nb_vignettes);
45   for(int n = 0; n < nb_vignettes; n++) {
46     if(labels[n] == 1) {
47       nb_1++;
48       for(int k = 0; k < Vignette::width * Vignette::height; k++) {
49         proba_given_1[k * Vignette::nb_grayscales + vignettes[n].content[k]] += 1.0;
50       }
51     } else {
52       nb_0++;
53       for(int k = 0; k < Vignette::width * Vignette::height; k++) {
54         proba_given_0[k * Vignette::nb_grayscales + vignettes[n].content[k]] += 1.0;
55       }
56     }
57     global.bar.refresh(&cout, n);
58   }
59   global.bar.finish(&cout);
60
61   for(int k = 0; k < Vignette::width * Vignette::height * Vignette::nb_grayscales; k++) {
62     proba_given_0[k] /= scalar_t(nb_0);
63     proba_given_1[k] /= scalar_t(nb_1);
64   }
65 }
66
67 scalar_t NaiveBayesianClassifier::classify(Vignette *vignette) {
68   scalar_t result = 0.0;
69
70   for(int k = 0; k < Vignette::width * Vignette::height; k++) {
71     result += log(proba_given_1[k * Vignette::nb_grayscales + vignette->content[k]])
72       - log(proba_given_0[k * Vignette::nb_grayscales + vignette->content[k]]);
73   }
74
75   return result;
76 }
77
78 void NaiveBayesianClassifier::read(istream *in) {
79   in->read((char *) proba_given_0, sizeof(scalar_t) * Vignette::width * Vignette::height * Vignette::nb_grayscales);
80   in->read((char *) proba_given_1, sizeof(scalar_t) * Vignette::width * Vignette::height * Vignette::nb_grayscales);
81 }
82
83 void NaiveBayesianClassifier::write(ostream *out) {
84   int t;
85   t = CT_NAIVE_BAYESIAN;
86   write_var(out, &t);
87   out->write((char *) proba_given_0, sizeof(scalar_t) * Vignette::width * Vignette::height * Vignette::nb_grayscales);
88   out->write((char *) proba_given_1, sizeof(scalar_t) * Vignette::width * Vignette::height * Vignette::nb_grayscales);
89 }