// decision_tree.cc
/*
 *  folded-ctf is an implementation of the folded hierarchy of
 *  classifiers for object detection, developed by Francois Fleuret
 *  and Donald Geman.
 *
 *  Copyright (c) 2008 Idiap Research Institute, http://www.idiap.ch/
 *  Written by Francois Fleuret <francois.fleuret@idiap.ch>
 *
 *  This file is part of folded-ctf.
 *
 *  folded-ctf is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published
 *  by the Free Software Foundation, either version 3 of the License,
 *  or (at your option) any later version.
 *
 *  folded-ctf is distributed in the hope that it will be useful, but
 *  WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 *  General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with folded-ctf.  If not, see <http://www.gnu.org/licenses/>.
 *
 */

#include "decision_tree.h"
#include "fusion_sort.h"

DecisionTree::DecisionTree() {
  _feature_index = -1;
  _threshold = 0;
  _weight = 0;
  _subtree_greater = 0;
  _subtree_lesser = 0;
}

DecisionTree::~DecisionTree() {
  if(_subtree_lesser)
    delete _subtree_lesser;
  if(_subtree_greater)
    delete _subtree_greater;
}

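/*
 * Structural invariant relied upon below: an internal node always
 * carries both _subtree_lesser and _subtree_greater, while a leaf
 * carries neither and stores its prediction in _weight.
 */
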
int DecisionTree::nb_leaves() {
  // Both subtrees are dereferenced below, and the invariant above
  // guarantees they exist together, hence the && test (as in
  // response()).
  if(_subtree_lesser && _subtree_greater)
    return _subtree_lesser->nb_leaves() + _subtree_greater->nb_leaves();
  else
    return 1;
}

int DecisionTree::depth() {
  if(_subtree_lesser && _subtree_greater)
    return 1 + max(_subtree_lesser->depth(), _subtree_greater->depth());
  else
    return 1;
}

scalar_t DecisionTree::response(SampleSet *sample_set, int n_sample) {
  if(_subtree_lesser && _subtree_greater) {
    if(sample_set->feature_value(n_sample, _feature_index) < _threshold)
      return _subtree_lesser->response(sample_set, n_sample);
    else
      return _subtree_greater->response(sample_set, n_sample);
  } else {
    return _weight;
  }
}

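/*
 * pick_best_split performs an exhaustive search over stumps. For each
 * feature the samples are sorted by feature value and the signed sum
 * of the loss derivatives is maintained incrementally: it starts at
 * the total sum with every sample on the "greater" side, and moving
 * sample t to the "lesser" side changes it by -2 * d_t, so at any
 * point sum equals the greater-side total minus the lesser-side
 * total. The split retained is the one maximizing |sum|.
 */
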
void DecisionTree::pick_best_split(SampleSet *sample_set, scalar_t *loss_derivatives) {

  int nb_samples = sample_set->nb_samples();

  scalar_t *responses = new scalar_t[nb_samples];
  int *indexes = new int[nb_samples];
  int *sorted_indexes = new int[nb_samples];

  scalar_t max_abs_sum = 0;
  _feature_index = -1;

  for(int f = 0; f < sample_set->nb_features(); f++) {
    scalar_t sum = 0;

    for(int s = 0; s < nb_samples; s++) {
      indexes[s] = s;
      responses[s] = sample_set->feature_value(s, f);
      sum += loss_derivatives[s];
    }

    indexed_fusion_sort(nb_samples, indexes, sorted_indexes, responses);

    // Sweep the candidate thresholds in increasing order of feature
    // value, updating sum as each sample crosses to the lesser side.
    int t, u = sorted_indexes[0];
    for(int s = 0; s < nb_samples - 1; s++) {
      t = u;
      u = sorted_indexes[s + 1];
      sum -= 2 * loss_derivatives[t];

      // Only cut between two distinct feature values, halfway between them.
      if(responses[t] < responses[u] && abs(sum) > max_abs_sum) {
        max_abs_sum = abs(sum);
        _feature_index = f;
        _threshold = (responses[t] + responses[u])/2;
      }
    }
  }

  delete[] indexes;
  delete[] sorted_indexes;
  delete[] responses;
}

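/*
 * Recursive training. The recursion stops when the maximum depth is
 * reached or when either class falls below min_nb_samples_for_split.
 * Otherwise the samples are partitioned with the best stump and the
 * two halves are trained recursively on sub-sample sets. A node that
 * ends up without subtrees becomes a leaf whose _weight is the
 * optimal constant response for the loss, clamped to [-10, 10].
 */
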
void DecisionTree::train(LossMachine *loss_machine,
                         SampleSet *sample_set,
                         scalar_t *current_responses,
                         scalar_t *loss_derivatives,
                         int depth) {

  if(_subtree_lesser || _subtree_greater || _feature_index >= 0) {
    cerr << "You cannot re-train a tree." << endl;
    abort();
  }

  int nb_samples = sample_set->nb_samples();

  int nb_pos = 0, nb_neg = 0;
  for(int s = 0; s < nb_samples; s++) {
    if(sample_set->label(s) > 0) nb_pos++;
    else if(sample_set->label(s) < 0) nb_neg++;
  }

  (*global.log_stream) << "Training tree" << endl;
  (*global.log_stream) << "  nb_samples " << nb_samples << endl;
  (*global.log_stream) << "  depth " << depth << endl;
  (*global.log_stream) << "  nb_pos = " << nb_pos << endl;
  (*global.log_stream) << "  nb_neg = " << nb_neg << endl;

  if(depth >= global.tree_depth_max)
    (*global.log_stream) << "  Maximum depth reached." << endl;
  if(nb_pos < min_nb_samples_for_split)
    (*global.log_stream) << "  Not enough positive samples." << endl;
  if(nb_neg < min_nb_samples_for_split)
    (*global.log_stream) << "  Not enough negative samples." << endl;

  if(depth < global.tree_depth_max &&
     nb_pos >= min_nb_samples_for_split &&
     nb_neg >= min_nb_samples_for_split) {

    pick_best_split(sample_set, loss_derivatives);

    if(_feature_index >= 0) {
      // Partition the samples: the lesser side fills the arrays from
      // the front, the greater side from the back. Heap allocation
      // instead of a variable-length array, which is not standard C++.
      int *indexes = new int[nb_samples];
      scalar_t *parted_current_responses = new scalar_t[nb_samples];
      scalar_t *parted_loss_derivatives = new scalar_t[nb_samples];

      int nb_lesser = 0, nb_greater = 0;
      int nb_lesser_pos = 0, nb_lesser_neg = 0, nb_greater_pos = 0, nb_greater_neg = 0;

      for(int s = 0; s < nb_samples; s++) {
        if(sample_set->feature_value(s, _feature_index) < _threshold) {
          indexes[nb_lesser] = s;
          parted_current_responses[nb_lesser] = current_responses[s];
          parted_loss_derivatives[nb_lesser] = loss_derivatives[s];

          if(sample_set->label(s) > 0)
            nb_lesser_pos++;
          else if(sample_set->label(s) < 0)
            nb_lesser_neg++;

          nb_lesser++;
        } else {
          nb_greater++;

          indexes[nb_samples - nb_greater] = s;
          parted_current_responses[nb_samples - nb_greater] = current_responses[s];
          parted_loss_derivatives[nb_samples - nb_greater] = loss_derivatives[s];

          if(sample_set->label(s) > 0)
            nb_greater_pos++;
          else if(sample_set->label(s) < 0)
            nb_greater_neg++;
        }
      }

      if((nb_lesser_pos >= min_nb_samples_for_split ||
          nb_lesser_neg >= min_nb_samples_for_split) &&
         (nb_greater_pos >= min_nb_samples_for_split ||
          nb_greater_neg >= min_nb_samples_for_split)) {

        _subtree_lesser = new DecisionTree();

        {
          SampleSet sub_sample_set(sample_set, nb_lesser, indexes);

          _subtree_lesser->train(loss_machine,
                                 &sub_sample_set,
                                 parted_current_responses,
                                 parted_loss_derivatives,
                                 depth + 1);
        }

        _subtree_greater = new DecisionTree();

        {
          SampleSet sub_sample_set(sample_set, nb_greater, indexes + nb_lesser);

          _subtree_greater->train(loss_machine,
                                  &sub_sample_set,
                                  parted_current_responses + nb_lesser,
                                  parted_loss_derivatives + nb_lesser,
                                  depth + 1);
        }
      }

      delete[] indexes;
      delete[] parted_current_responses;
      delete[] parted_loss_derivatives;
    } else {
      (*global.log_stream) << "Could not find a feature for split." << endl;
    }
  }

  if(!(_subtree_greater && _subtree_lesser)) {
    // This node is a leaf: give it the optimal constant response,
    // clamped to avoid degenerate weights.
    scalar_t *tmp_responses = new scalar_t[nb_samples];
    for(int s = 0; s < nb_samples; s++)
      tmp_responses[s] = 1;

    _weight = loss_machine->optimal_weight(sample_set, tmp_responses, current_responses);

    const scalar_t max_weight = 10.0;

    if(_weight > max_weight) {
      _weight = max_weight;
    } else if(_weight < - max_weight) {
      _weight = - max_weight;
    }

    (*global.log_stream) << "  _weight " << _weight << endl;

    delete[] tmp_responses;
  }
}

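/*
 * Top-level entry point: computes the loss derivatives at the current
 * responses, then starts the recursion at depth 0.
 */
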
void DecisionTree::train(LossMachine *loss_machine,
                         SampleSet *sample_set,
                         scalar_t *current_responses) {

  scalar_t *loss_derivatives = new scalar_t[sample_set->nb_samples()];

  loss_machine->get_loss_derivatives(sample_set, current_responses, loss_derivatives);

  train(loss_machine, sample_set, current_responses, loss_derivatives, 0);

  delete[] loss_derivatives;
}

//////////////////////////////////////////////////////////////////////

void DecisionTree::tag_used_features(bool *used) {
  if(_subtree_lesser && _subtree_greater) {
    used[_feature_index] = true;
    _subtree_lesser->tag_used_features(used);
    _subtree_greater->tag_used_features(used);
  }
}

void DecisionTree::re_index_features(int *new_indexes) {
  if(_subtree_lesser && _subtree_greater) {
    _feature_index = new_indexes[_feature_index];
    _subtree_lesser->re_index_features(new_indexes);
    _subtree_greater->re_index_features(new_indexes);
  }
}

//////////////////////////////////////////////////////////////////////

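/*
 * Serialization: nodes are written in prefix order. Each record holds
 * the feature index, the threshold, the weight, and an integer flag
 * indicating whether the node is split; when it is, the lesser and
 * greater subtrees follow immediately.
 */
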
void DecisionTree::read(istream *is) {
  if(_subtree_lesser || _subtree_greater) {
    cerr << "You cannot read into an existing tree." << endl;
    abort();
  }

  read_var(is, &_feature_index);
  read_var(is, &_threshold);
  read_var(is, &_weight);

  int split;
  read_var(is, &split);

  if(split) {
    _subtree_lesser = new DecisionTree();
    _subtree_lesser->read(is);
    _subtree_greater = new DecisionTree();
    _subtree_greater->read(is);
  }
}

void DecisionTree::write(ostream *os) {

  write_var(os, &_feature_index);
  write_var(os, &_threshold);
  write_var(os, &_weight);

  int split;
  if(_subtree_lesser && _subtree_greater) {
    split = 1;
    write_var(os, &split);
    _subtree_lesser->write(os);
    _subtree_greater->write(os);
  } else {
    split = 0;
    write_var(os, &split);
  }
}
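
/*
 * Illustrative usage sketch, assuming a SampleSet *sample_set and a
 * LossMachine *loss_machine obtained from the surrounding framework
 * (their construction is not shown in this file); it uses only the
 * interface defined above:
 *
 *   // Start from zero responses for every sample.
 *   scalar_t *responses = new scalar_t[sample_set->nb_samples()];
 *   for(int s = 0; s < sample_set->nb_samples(); s++)
 *     responses[s] = 0;
 *
 *   DecisionTree tree;
 *   tree.train(loss_machine, sample_set, responses);
 *
 *   (*global.log_stream) << "depth " << tree.depth()
 *                        << " nb_leaves " << tree.nb_leaves() << endl;
 *
 *   // Evaluate the trained stump/tree on the first sample.
 *   (*global.log_stream) << "response(0) "
 *                        << tree.response(sample_set, 0) << endl;
 *
 *   delete[] responses;
 */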