///////////////////////////////////////////////////////////////////////////
// This program is free software: you can redistribute it and/or modify //
// it under the terms of the version 3 of the GNU General Public License //
// as published by the Free Software Foundation.                         //
//                                                                       //
// This program is distributed in the hope that it will be useful, but   //
// WITHOUT ANY WARRANTY; without even the implied warranty of            //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      //
// General Public License for more details.                              //
//                                                                       //
// You should have received a copy of the GNU General Public License     //
// along with this program. If not, see <http://www.gnu.org/licenses/>.  //
//                                                                       //
// Written by Francois Fleuret                                           //
// (C) Idiap Research Institute                                          //
//                                                                       //
// Contact <francois.fleuret@idiap.ch> for comments & bug reports        //
///////////////////////////////////////////////////////////////////////////
#include "decision_tree.h"
#include "fusion_sort.h"
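// A DecisionTree is a binary regression tree used as a weak learner:
// each internal node routes a sample to its lesser or greater subtree
// by comparing one feature value to a threshold, and each leaf returns
// a constant weight.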
DecisionTree::DecisionTree() {
  // A fresh tree is a single leaf: no feature chosen yet, no subtrees
  _feature_index = -1;
  _threshold = 0;
  _weight = 0;
  _subtree_lesser = 0;
  _subtree_greater = 0;
}
DecisionTree::~DecisionTree() {
  delete _subtree_lesser;
  delete _subtree_greater;
}
int DecisionTree::nb_leaves() {
  if(_subtree_lesser || _subtree_greater)
    return _subtree_lesser->nb_leaves() + _subtree_greater->nb_leaves();
  else
    return 1;
}
int DecisionTree::depth() {
  if(_subtree_lesser || _subtree_greater)
    return 1 + max(_subtree_lesser->depth(), _subtree_greater->depth());
  else
    return 1; // a lone leaf counts as depth 1 here
}
scalar_t DecisionTree::response(SampleSet *sample_set, int n_sample) {
  if(_subtree_lesser && _subtree_greater) {
    if(sample_set->feature_value(n_sample, _feature_index) < _threshold)
      return _subtree_lesser->response(sample_set, n_sample);
    else
      return _subtree_greater->response(sample_set, n_sample);
  } else {
    return _weight;
  }
}
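// Selects _feature_index and _threshold. For every feature, samples
// are sorted by feature value and every mid-point between consecutive
// distinct values is considered as a threshold. The running variable
// sum is the sum of the loss derivatives on the greater side minus the
// sum on the lesser side, and the retained split is the one maximizing
// its absolute value.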
void DecisionTree::pick_best_split(SampleSet *sample_set, scalar_t *loss_derivatives) {

  int nb_samples = sample_set->nb_samples();

  scalar_t *responses = new scalar_t[nb_samples];
  int *indexes = new int[nb_samples];
  int *sorted_indexes = new int[nb_samples];

  scalar_t max_abs_sum = 0;

  for(int f = 0; f < sample_set->nb_features(); f++) {
    scalar_t sum = 0;
    for(int s = 0; s < nb_samples; s++) {
      indexes[s] = s;
      responses[s] = sample_set->feature_value(s, f);
      sum += loss_derivatives[s];
    }

    indexed_fusion_sort(nb_samples, indexes, sorted_indexes, responses);

    int t, u = sorted_indexes[0];
    for(int s = 0; s < nb_samples - 1; s++) {
      t = u;
      u = sorted_indexes[s + 1];
      sum -= 2 * loss_derivatives[t];

      if(responses[t] < responses[u] && abs(sum) > max_abs_sum) {
        max_abs_sum = abs(sum);
        _feature_index = f;
        _threshold = (responses[t] + responses[u]) / 2;
      }
    }
  }

  delete[] sorted_indexes;
  delete[] indexes;
  delete[] responses;
}
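// Grows the tree recursively: picks the best split, partitions the
// samples, and recurses while the depth limit and the per-class sample
// counts allow it. A node that does not split becomes a leaf whose
// weight is the optimal one according to the loss machine, clipped to
// avoid numerical trouble.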
void DecisionTree::train(LossMachine *loss_machine,
                         SampleSet *sample_set,
                         scalar_t *current_responses,
                         scalar_t *loss_derivatives,
                         int depth) {

  if(_subtree_lesser || _subtree_greater || _feature_index >= 0) {
    cerr << "You can not re-train a tree." << endl;
    exit(1);
  }

  int nb_samples = sample_set->nb_samples();

  int nb_pos = 0, nb_neg = 0;
  for(int s = 0; s < sample_set->nb_samples(); s++) {
    if(sample_set->label(s) > 0) nb_pos++;
    else if(sample_set->label(s) < 0) nb_neg++;
  }
  (*global.log_stream) << "Training tree" << endl;
  (*global.log_stream) << "  nb_samples " << nb_samples << endl;
  (*global.log_stream) << "  depth " << depth << endl;
  (*global.log_stream) << "  nb_pos = " << nb_pos << endl;
  (*global.log_stream) << "  nb_neg = " << nb_neg << endl;

  if(depth >= global.tree_depth_max)
    (*global.log_stream) << "  Maximum depth reached." << endl;
  if(nb_pos < min_nb_samples_for_split)
    (*global.log_stream) << "  Not enough positive samples." << endl;
  if(nb_neg < min_nb_samples_for_split)
    (*global.log_stream) << "  Not enough negative samples." << endl;

  if(depth < global.tree_depth_max &&
     nb_pos >= min_nb_samples_for_split &&
     nb_neg >= min_nb_samples_for_split) {

    pick_best_split(sample_set, loss_derivatives);
    if(_feature_index >= 0) {
      int indexes[nb_samples];
      scalar_t *parted_current_responses = new scalar_t[nb_samples];
      scalar_t *parted_loss_derivatives = new scalar_t[nb_samples];

      int nb_lesser = 0, nb_greater = 0;
      int nb_lesser_pos = 0, nb_lesser_neg = 0, nb_greater_pos = 0, nb_greater_neg = 0;
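      // Partition the samples in place: samples going to the lesser
      // subtree are packed at the front of indexes, those going to the
      // greater subtree at the back, so both sub-sample sets are
      // contiguous ranges of the same array.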
      for(int s = 0; s < nb_samples; s++) {
        if(sample_set->feature_value(s, _feature_index) < _threshold) {
          indexes[nb_lesser] = s;
          parted_current_responses[nb_lesser] = current_responses[s];
          parted_loss_derivatives[nb_lesser] = loss_derivatives[s];

          if(sample_set->label(s) > 0)
            nb_lesser_pos++;
          else if(sample_set->label(s) < 0)
            nb_lesser_neg++;

          nb_lesser++;
        } else {
          nb_greater++;

          indexes[nb_samples - nb_greater] = s;
          parted_current_responses[nb_samples - nb_greater] = current_responses[s];
          parted_loss_derivatives[nb_samples - nb_greater] = loss_derivatives[s];

          if(sample_set->label(s) > 0)
            nb_greater_pos++;
          else if(sample_set->label(s) < 0)
            nb_greater_neg++;
        }
      }
      if((nb_lesser_pos >= min_nb_samples_for_split ||
          nb_lesser_neg >= min_nb_samples_for_split) &&
         (nb_greater_pos >= min_nb_samples_for_split ||
          nb_greater_neg >= min_nb_samples_for_split)) {
        _subtree_lesser = new DecisionTree();
        {
          SampleSet sub_sample_set(sample_set, nb_lesser, indexes);
          _subtree_lesser->train(loss_machine,
                                 &sub_sample_set,
                                 parted_current_responses,
                                 parted_loss_derivatives,
                                 depth + 1);
        }

        _subtree_greater = new DecisionTree();
        {
          SampleSet sub_sample_set(sample_set, nb_greater, indexes + nb_lesser);
          _subtree_greater->train(loss_machine,
                                  &sub_sample_set,
                                  parted_current_responses + nb_lesser,
                                  parted_loss_derivatives + nb_lesser,
                                  depth + 1);
        }
      }

      delete[] parted_current_responses;
      delete[] parted_loss_derivatives;
    } else {
      (*global.log_stream) << "Could not find a feature for split." << endl;
    }
  }
  if(!(_subtree_greater && _subtree_lesser)) {
    scalar_t *tmp_responses = new scalar_t[nb_samples];
    for(int s = 0; s < nb_samples; s++)
      tmp_responses[s] = 1;

    _weight = loss_machine->optimal_weight(sample_set, tmp_responses, current_responses);

    const scalar_t max_weight = 10.0;

    if(_weight > max_weight) {
      _weight = max_weight;
    } else if(_weight < - max_weight) {
      _weight = - max_weight;
    }

    (*global.log_stream) << "  _weight " << _weight << endl;

    delete[] tmp_responses;
  }
}
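// Entry point: computes the loss derivatives for the current responses
// and grows the tree from depth 0.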
void DecisionTree::train(LossMachine *loss_machine,
                         SampleSet *sample_set,
                         scalar_t *current_responses) {

  scalar_t *loss_derivatives = new scalar_t[sample_set->nb_samples()];

  loss_machine->get_loss_derivatives(sample_set, current_responses, loss_derivatives);

  train(loss_machine, sample_set, current_responses, loss_derivatives, 0);

  delete[] loss_derivatives;
}
//////////////////////////////////////////////////////////////////////
void DecisionTree::tag_used_features(bool *used) {
  if(_subtree_lesser && _subtree_greater) {
    used[_feature_index] = true;
    _subtree_lesser->tag_used_features(used);
    _subtree_greater->tag_used_features(used);
  }
}
void DecisionTree::re_index_features(int *new_indexes) {
  if(_subtree_lesser && _subtree_greater) {
    _feature_index = new_indexes[_feature_index];
    _subtree_lesser->re_index_features(new_indexes);
    _subtree_greater->re_index_features(new_indexes);
  }
}
//////////////////////////////////////////////////////////////////////
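// Serialization: a node is written as its feature index, threshold and
// weight, followed by an integer flag indicating whether it is split;
// if so, the two subtrees follow, lesser first.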
void DecisionTree::read(istream *is) {
  if(_subtree_lesser || _subtree_greater) {
    cerr << "You can not read in an existing tree." << endl;
    exit(1);
  }

  read_var(is, &_feature_index);
  read_var(is, &_threshold);
  read_var(is, &_weight);

  int split;
  read_var(is, &split);

  if(split) {
    _subtree_lesser = new DecisionTree();
    _subtree_lesser->read(is);
    _subtree_greater = new DecisionTree();
    _subtree_greater->read(is);
  }
}
void DecisionTree::write(ostream *os) {
  write_var(os, &_feature_index);
  write_var(os, &_threshold);
  write_var(os, &_weight);

  if(_subtree_lesser && _subtree_greater) {
    int split = 1;
    write_var(os, &split);
    _subtree_lesser->write(os);
    _subtree_greater->write(os);
  } else {
    int split = 0;
    write_var(os, &split);
  }
}