Recoded to UTF-8.
[cmim.git] / classifier.h
1
2 //////////////////////////////////////////////////////////////////////////////////
3 // This program is free software: you can redistribute it and/or modify         //
4 // it under the terms of the version 3 of the GNU General Public License        //
5 // as published by the Free Software Foundation.                                //
6 //                                                                              //
7 // This program is distributed in the hope that it will be useful, but          //
8 // WITHOUT ANY WARRANTY; without even the implied warranty of                   //
9 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU             //
10 // General Public License for more details.                                     //
11 //                                                                              //
12 // You should have received a copy of the GNU General Public License            //
13 // along with this program. If not, see <http://www.gnu.org/licenses/>.         //
14 //                                                                              //
15 // Written by Francois Fleuret                                                  //
16 // Copyright (C) Ecole Polytechnique Federale de Lausanne                       //
17 // Contact <francois.fleuret@epfl.ch> for comments & bug reports                //
18 //////////////////////////////////////////////////////////////////////////////////
19
20 // $Id: classifier.h,v 1.3 2007-08-23 08:36:50 fleuret Exp $
21
22 #ifndef CLASSIFIER_H
23 #define CLASSIFIER_H
24
25 using namespace std;
26
27 #include <iostream>
28 #include <fstream>
29
30 #include "misc.h"
31 #include "fastentropy.h"
32
33 class FeatureSelector;
34
35 class DataSet {
36   // This keeps track of the number of references
37   struct RawData {
38     int nrefs;
39     uint32_t *x, *y;
40   };
41 public:
42   int nb_samples, nb_features;
43   int size;
44   RawData *raw;
45   uint32_t *y_va, **x_va;
46   DataSet(istream &is);
47   DataSet(int nb_samples, int nb_features);
48   DataSet(const DataSet &ds);
49   DataSet(const DataSet &ds, const FeatureSelector &fs);
50   DataSet(const DataSet &ds, bool *selected_samples);
51   DataSet &operator = (const DataSet &ds);
52   ~DataSet();
53   void copy(const DataSet &ds);
54   void save_ascii(ostream &os);
55 };
56
57 //////////////////////////////////////////////////////////////////////
58 // The classifier ////////////////////////////////////////////////////
59 //////////////////////////////////////////////////////////////////////
60
61 class Classifier {
62 public:
63   enum { ID_LINEAR };
64   static Classifier *load(istream &is);
65   virtual ~Classifier();
66   virtual void predict(const DataSet &ds, float *result) = 0;
67   virtual void save(ostream &out) = 0;
68   virtual void inner_load(istream &in) = 0;
69 };
70
71 class FeatureSelector {
72 public:
73   int nb_selected_features;
74   int *selected_index;
75
76   // Those remains from the feature selection process. They can be
77   // used as-is in the case of adaboost
78   float *weights;
79
80   FeatureSelector(istream &is);
81   FeatureSelector(int nb_selected_features);
82   ~FeatureSelector();
83
84   // Selects features according to the Conditional Mutual Information Maximisation
85   void cmim(const DataSet &ds);
86
87   // Selects features according to the Mutual Information Maximisation
88   void mim(const DataSet &ds);
89
90   // Selects random features
91   void random(const DataSet &ds);
92
93   void save(ostream &os);
94 };
95
96 class LinearClassifier : public Classifier {
97   int nb_features;
98   float *weights;
99   float bias;
100 public:
101   LinearClassifier();
102   LinearClassifier(int nb_features);
103   virtual ~LinearClassifier();
104
105   void compute_bayesian_weights(int nb_samples, uint32_t *y_va, uint32_t **x_va);
106   void compute_bias(int nb_samples, uint32_t *y_va, uint32_t **x_va, bool balanced_error);
107
108   void learn_bayesian(const DataSet &ds, bool balanced_error);
109   void learn_perceptron(const DataSet &ds, bool balanced_error);
110
111   virtual void predict(const DataSet &ds, float *result);
112   virtual void save(ostream &out);
113   virtual void inner_load(istream &is);
114 };
115
116 void compute_error_rates(FeatureSelector *selector, Classifier *classifier,
117                          const DataSet &testing_set, int &n00, int &n01, int &n10, int &n11, float *result);
118
119 #endif