///////////////////////////////////////////////////////////////////////////
// This program is free software: you can redistribute it and/or modify  //
// it under the terms of version 3 of the GNU General Public License     //
// as published by the Free Software Foundation.                         //
//                                                                       //
// This program is distributed in the hope that it will be useful, but   //
// WITHOUT ANY WARRANTY; without even the implied warranty of            //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      //
// General Public License for more details.                              //
//                                                                       //
// You should have received a copy of the GNU General Public License     //
// along with this program. If not, see <http://www.gnu.org/licenses/>.  //
//                                                                       //
// Written by Francois Fleuret, (C) IDIAP                                //
// Contact <francois.fleuret@idiap.ch> for comments & bug reports        //
///////////////////////////////////////////////////////////////////////////

#include "decision_tree.h"
#include "fusion_sort.h"

DecisionTree::DecisionTree() {
  _feature_index = -1;
  _threshold = 0;
  _weight = 0;
  _subtree_greater = 0;
  _subtree_lesser = 0;
}

DecisionTree::~DecisionTree() {
  if(_subtree_lesser)
    delete _subtree_lesser;
  if(_subtree_greater)
    delete _subtree_greater;
}

int DecisionTree::nb_leaves() {
  if(_subtree_lesser && _subtree_greater)
    return _subtree_lesser->nb_leaves() + _subtree_greater->nb_leaves();
  else
    return 1;
}

int DecisionTree::depth() {
  if(_subtree_lesser && _subtree_greater)
    return 1 + max(_subtree_lesser->depth(), _subtree_greater->depth());
  else
    return 1;
}

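// Returns the response of the tree for the given sample: follows the
// thresholded feature tests down to a leaf and returns that leaf's weight.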
scalar_t DecisionTree::response(SampleSet *sample_set, int n_sample) {
  if(_subtree_lesser && _subtree_greater) {
    if(sample_set->feature_value(n_sample, _feature_index) < _threshold)
      return _subtree_lesser->response(sample_set, n_sample);
    else
      return _subtree_greater->response(sample_set, n_sample);
  } else {
    return _weight;
  }
}

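// Picks the split that maximizes the absolute value of the difference
// between the sums of the loss derivatives on the two sides of the
// threshold, i.e. the +/-1 stump whose weighting decreases the loss
// fastest. The scan starts with every sample on the "greater" side, so sum
// is initially the total sum of the derivatives; moving sample t to the
// "lesser" side changes it by -2 * loss_derivatives[t]. A threshold is
// only placed between two distinct consecutive feature values, halfway
// between them. Leaves _feature_index at -1 if no valid split exists.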
void DecisionTree::pick_best_split(SampleSet *sample_set, scalar_t *loss_derivatives) {

  int nb_samples = sample_set->nb_samples();

  scalar_t *responses = new scalar_t[nb_samples];
  int *indexes = new int[nb_samples];
  int *sorted_indexes = new int[nb_samples];

  scalar_t max_abs_sum = 0;
  _feature_index = -1;

  for(int f = 0; f < sample_set->nb_features(); f++) {
    scalar_t sum = 0;

    for(int s = 0; s < nb_samples; s++) {
      indexes[s] = s;
      responses[s] = sample_set->feature_value(s, f);
      sum += loss_derivatives[s];
    }

    indexed_fusion_sort(nb_samples, indexes, sorted_indexes, responses);

    int t, u = sorted_indexes[0];
    for(int s = 0; s < nb_samples - 1; s++) {
      t = u;
      u = sorted_indexes[s + 1];
      sum -= 2 * loss_derivatives[t];

      if(responses[t] < responses[u] && abs(sum) > max_abs_sum) {
        max_abs_sum = abs(sum);
        _feature_index = f;
        _threshold = (responses[t] + responses[u])/2;
      }
    }
  }

  delete[] indexes;
  delete[] sorted_indexes;
  delete[] responses;
}

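// Grows the tree recursively. A node is split as long as the maximum depth
// has not been reached and both classes have at least
// min_nb_samples_for_split samples; the split is kept only if each side
// retains enough samples of at least one class. A node left without
// subtrees becomes a leaf and receives a constant weight fitted with the
// loss machine.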
void DecisionTree::train(LossMachine *loss_machine,
                         SampleSet *sample_set,
                         scalar_t *current_responses,
                         scalar_t *loss_derivatives,
                         int depth) {

  if(_subtree_lesser || _subtree_greater || _feature_index >= 0) {
    cerr << "You cannot re-train a tree." << endl;
    abort();
  }

  int nb_samples = sample_set->nb_samples();

  int nb_pos = 0, nb_neg = 0;
  for(int s = 0; s < nb_samples; s++) {
    if(sample_set->label(s) > 0) nb_pos++;
    else if(sample_set->label(s) < 0) nb_neg++;
  }

  (*global.log_stream) << "Training tree" << endl;
  (*global.log_stream) << "  nb_samples " << nb_samples << endl;
  (*global.log_stream) << "  depth " << depth << endl;
  (*global.log_stream) << "  nb_pos = " << nb_pos << endl;
  (*global.log_stream) << "  nb_neg = " << nb_neg << endl;

  if(depth >= global.tree_depth_max)
    (*global.log_stream) << "  Maximum depth reached." << endl;
  if(nb_pos < min_nb_samples_for_split)
    (*global.log_stream) << "  Not enough positive samples." << endl;
  if(nb_neg < min_nb_samples_for_split)
    (*global.log_stream) << "  Not enough negative samples." << endl;

  if(depth < global.tree_depth_max &&
     nb_pos >= min_nb_samples_for_split &&
     nb_neg >= min_nb_samples_for_split) {

    pick_best_split(sample_set, loss_derivatives);

    if(_feature_index >= 0) {
      // Partition the samples around the threshold: the "lesser" ones are
      // packed at the beginning of the arrays, the "greater" ones at the end.
      int *indexes = new int[nb_samples];
      scalar_t *parted_current_responses = new scalar_t[nb_samples];
      scalar_t *parted_loss_derivatives = new scalar_t[nb_samples];

      int nb_lesser = 0, nb_greater = 0;
      int nb_lesser_pos = 0, nb_lesser_neg = 0, nb_greater_pos = 0, nb_greater_neg = 0;

      for(int s = 0; s < nb_samples; s++) {
        if(sample_set->feature_value(s, _feature_index) < _threshold) {
          indexes[nb_lesser] = s;
          parted_current_responses[nb_lesser] = current_responses[s];
          parted_loss_derivatives[nb_lesser] = loss_derivatives[s];

          if(sample_set->label(s) > 0)
            nb_lesser_pos++;
          else if(sample_set->label(s) < 0)
            nb_lesser_neg++;

          nb_lesser++;
        } else {
          nb_greater++;

          indexes[nb_samples - nb_greater] = s;
          parted_current_responses[nb_samples - nb_greater] = current_responses[s];
          parted_loss_derivatives[nb_samples - nb_greater] = loss_derivatives[s];

          if(sample_set->label(s) > 0)
            nb_greater_pos++;
          else if(sample_set->label(s) < 0)
            nb_greater_neg++;
        }
      }

      if((nb_lesser_pos >= min_nb_samples_for_split ||
          nb_lesser_neg >= min_nb_samples_for_split) &&
         (nb_greater_pos >= min_nb_samples_for_split ||
          nb_greater_neg >= min_nb_samples_for_split)) {

        _subtree_lesser = new DecisionTree();

        {
          SampleSet sub_sample_set(sample_set, nb_lesser, indexes);

          _subtree_lesser->train(loss_machine,
                                 &sub_sample_set,
                                 parted_current_responses,
                                 parted_loss_derivatives,
                                 depth + 1);
        }

        _subtree_greater = new DecisionTree();

        {
          SampleSet sub_sample_set(sample_set, nb_greater, indexes + nb_lesser);

          _subtree_greater->train(loss_machine,
                                  &sub_sample_set,
                                  parted_current_responses + nb_lesser,
                                  parted_loss_derivatives + nb_lesser,
                                  depth + 1);
        }
      }

      delete[] parted_current_responses;
      delete[] parted_loss_derivatives;
      delete[] indexes;
    } else {
      (*global.log_stream) << "Could not find a feature for a split." << endl;
    }
  }

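  // Leaf case: no split was made, so the node carries a single constant
  // weight. With all the pseudo-responses set to 1, optimal_weight returns
  // the additive constant minimizing the loss on this node's samples; it is
  // clamped to [-max_weight, max_weight] to keep any single update bounded.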
  if(!(_subtree_greater && _subtree_lesser)) {
    scalar_t *tmp_responses = new scalar_t[nb_samples];
    for(int s = 0; s < nb_samples; s++)
      tmp_responses[s] = 1;

    _weight = loss_machine->optimal_weight(sample_set, tmp_responses, current_responses);

    const scalar_t max_weight = 10.0;

    if(_weight > max_weight) {
      _weight = max_weight;
    } else if(_weight < -max_weight) {
      _weight = -max_weight;
    }

    (*global.log_stream) << "  _weight " << _weight << endl;

    delete[] tmp_responses;
  }
}

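// Public entry point: computes the loss derivatives at the current
// responses, then grows the tree from the root.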
void DecisionTree::train(LossMachine *loss_machine,
                         SampleSet *sample_set,
                         scalar_t *current_responses) {

  scalar_t *loss_derivatives = new scalar_t[sample_set->nb_samples()];

  loss_machine->get_loss_derivatives(sample_set, current_responses, loss_derivatives);

  train(loss_machine, sample_set, current_responses, loss_derivatives, 0);

  delete[] loss_derivatives;
}

//////////////////////////////////////////////////////////////////////

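// Feature bookkeeping: tag_used_features marks in the used array every
// feature index tested by an internal node, and re_index_features remaps
// those indexes through new_indexes. Both are no-ops on leaves.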
void DecisionTree::tag_used_features(bool *used) {
  if(_subtree_lesser && _subtree_greater) {
    used[_feature_index] = true;
    _subtree_lesser->tag_used_features(used);
    _subtree_greater->tag_used_features(used);
  }
}

void DecisionTree::re_index_features(int *new_indexes) {
  if(_subtree_lesser && _subtree_greater) {
    _feature_index = new_indexes[_feature_index];
    _subtree_lesser->re_index_features(new_indexes);
    _subtree_greater->re_index_features(new_indexes);
  }
}

//////////////////////////////////////////////////////////////////////

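// Serialization: each node is stored as its feature index, threshold and
// weight, followed by a split flag; when the flag is set, the two subtrees
// follow recursively, lesser first, so read and write traverse the tree in
// the same pre-order.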
void DecisionTree::read(istream *is) {
  if(_subtree_lesser || _subtree_greater) {
    cerr << "You cannot read into an existing tree." << endl;
    abort();
  }

  read_var(is, &_feature_index);
  read_var(is, &_threshold);
  read_var(is, &_weight);

  int split;
  read_var(is, &split);

  if(split) {
    _subtree_lesser = new DecisionTree();
    _subtree_lesser->read(is);
    _subtree_greater = new DecisionTree();
    _subtree_greater->read(is);
  }
}

void DecisionTree::write(ostream *os) {

  write_var(os, &_feature_index);
  write_var(os, &_threshold);
  write_var(os, &_weight);

  int split;
  if(_subtree_lesser && _subtree_greater) {
    split = 1;
    write_var(os, &split);
    _subtree_lesser->write(os);
    _subtree_greater->write(os);
  } else {
    split = 0;
    write_var(os, &split);
  }
}
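
//////////////////////////////////////////////////////////////////////

// Minimal usage sketch, assuming a LossMachine and a SampleSet have been
// built and current_responses holds one scalar per sample, as in a
// boosting loop; the variable names here are illustrative only:
//
//   DecisionTree tree;
//   tree.train(&loss_machine, &sample_set, current_responses);
//
//   for(int s = 0; s < sample_set.nb_samples(); s++)
//     current_responses[s] += tree.response(&sample_set, s);
//
//   ofstream out("tree.bin");
//   tree.write(&out);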