automatic commit
[folded-ctf.git] / decision_tree.h
1 /*
2  *  folded-ctf is an implementation of the folded hierarchy of
3  *  classifiers for object detection, developed by Francois Fleuret
4  *  and Donald Geman.
5  *
6  *  Copyright (c) 2008 Idiap Research Institute, http://www.idiap.ch/
7  *  Written by Francois Fleuret <francois.fleuret@idiap.ch>
8  *
9  *  This file is part of folded-ctf.
10  *
11  *  folded-ctf is free software: you can redistribute it and/or modify
12  *  it under the terms of the GNU General Public License as published
13  *  by the Free Software Foundation, either version 3 of the License,
14  *  or (at your option) any later version.
15  *
16  *  folded-ctf is distributed in the hope that it will be useful, but
17  *  WITHOUT ANY WARRANTY; without even the implied warranty of
18  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19  *  General Public License for more details.
20  *
21  *  You should have received a copy of the GNU General Public License
22  *  along with folded-ctf.  If not, see <http://www.gnu.org/licenses/>.
23  *
24  */
25
26 /*
27
28   An implementation of the classifier with a decision tree. Each node
29   simply thresholds one of the component, and is chosen for maximum
30   loss reduction locally during training. The leaves are labelled with
31   the classifier response, which is chosen again for maximum loss
32   reduction.
33
34  */
35
36 #ifndef DECISION_TREE_H
37 #define DECISION_TREE_H
38
39 #include "misc.h"
40 #include "classifier.h"
41 #include "sample_set.h"
42 #include "loss_machine.h"
43
44 class DecisionTree : public Classifier {
45
46   static const int min_nb_samples_for_split = 5;
47
48   int _feature_index;
49   scalar_t _threshold;
50   scalar_t _weight;
51
52   DecisionTree *_subtree_lesser, *_subtree_greater;
53
54   void pick_best_split(SampleSet *sample_set,
55                        scalar_t *loss_derivatives);
56
57   void train(LossMachine *loss_machine,
58              SampleSet *sample_set,
59              scalar_t *current_responses,
60              scalar_t *loss_derivatives,
61              int depth);
62
63 public:
64
65   DecisionTree();
66   ~DecisionTree();
67
68   int nb_leaves();
69   int depth();
70
71   scalar_t response(SampleSet *sample_set, int n_sample);
72
73   void train(LossMachine *loss_machine,
74              SampleSet *sample_set,
75              scalar_t *current_responses);
76
77   void tag_used_features(bool *used);
78   void re_index_features(int *new_indexes);
79
80   void read(istream *is);
81   void write(ostream *os);
82 };
83
84 #endif