/*
 * folded-ctf is an implementation of the folded hierarchy of
 * classifiers for object detection, developed by Francois Fleuret.
 *
 * Copyright (c) 2008 Idiap Research Institute, http://www.idiap.ch/
 * Written by Francois Fleuret <francois.fleuret@idiap.ch>
 *
 * This file is part of folded-ctf.
 *
 * folded-ctf is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published
 * by the Free Software Foundation, either version 3 of the License,
 * or (at your option) any later version.
 *
 * folded-ctf is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with folded-ctf. If not, see <http://www.gnu.org/licenses/>.
 */

#include "decision_tree.h"
#include "fusion_sort.h"

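// Overview (editor's note): this file implements the small decision
// trees that appear to serve as weak learners in a boosting-style
// training loop. An internal node stores a feature index and a
// threshold and routes samples to two subtrees; a leaf stores a
// constant response _weight. SampleSet, LossMachine, scalar_t,
// min_nb_samples_for_split, read_var and write_var are defined
// elsewhere in the repository.
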
DecisionTree::DecisionTree() {
  // A freshly built tree is an untrained leaf: no split chosen yet,
  // no subtrees. train() relies on _feature_index < 0 and null
  // subtree pointers to detect an untrained tree.
  _feature_index = -1;
  _threshold = 0;
  _weight = 0;
  _subtree_lesser = 0;
  _subtree_greater = 0;
}

DecisionTree::~DecisionTree() {
  // Deleting a null pointer is a no-op, so leaves need no special case.
  delete _subtree_lesser;
  delete _subtree_greater;
}

int DecisionTree::nb_leaves() {
  if(_subtree_lesser || _subtree_greater)
    return _subtree_lesser->nb_leaves() + _subtree_greater->nb_leaves();
  else
    return 1;
}

int DecisionTree::depth() {
  if(_subtree_lesser || _subtree_greater)
    return 1 + max(_subtree_lesser->depth(), _subtree_greater->depth());
  else
    return 1;
}

scalar_t DecisionTree::response(SampleSet *sample_set, int n_sample) {
  if(_subtree_lesser && _subtree_greater) {
    // Internal node: route the sample according to the split.
    if(sample_set->feature_value(n_sample, _feature_index) < _threshold)
      return _subtree_lesser->response(sample_set, n_sample);
    else
      return _subtree_greater->response(sample_set, n_sample);
  } else {
    // Leaf: constant response chosen during training.
    return _weight;
  }
}

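// How the split search below works (editor's note): for each feature,
// the samples are sorted by feature value and sum is initialized to
// the total of all loss derivatives, i.e. with every sample on the
// "greater" side. Walking the sorted order, sum -= 2 * loss_derivatives[t]
// moves sample t to the "lesser" side, so at every step sum equals the
// derivative total over the greater side minus the total over the
// lesser side. The retained split maximizes |sum| across all features
// and positions, with the threshold placed halfway between two
// consecutive distinct feature values.
//
// Tiny illustration with derivatives {+1, +1, -2}, listed in sorted
// order of distinct feature values: sum starts at the total, 0; moving
// the first sample gives sum = -2, moving the second gives sum = -4,
// so the best split (|sum| = 4) separates the two +1 samples from the
// -2 one.
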
void DecisionTree::pick_best_split(SampleSet *sample_set,
                                   scalar_t *loss_derivatives) {

  int nb_samples = sample_set->nb_samples();

  scalar_t *responses = new scalar_t[nb_samples];
  int *indexes = new int[nb_samples];
  int *sorted_indexes = new int[nb_samples];

  scalar_t max_abs_sum = 0;

  for(int f = 0; f < sample_set->nb_features(); f++) {
    scalar_t sum = 0;

    for(int s = 0; s < nb_samples; s++) {
      indexes[s] = s;
      responses[s] = sample_set->feature_value(s, f);
      sum += loss_derivatives[s];
    }

    indexed_fusion_sort(nb_samples, indexes, sorted_indexes, responses);

    int t, u = sorted_indexes[0];

    for(int s = 0; s < nb_samples - 1; s++) {
      t = u;
      u = sorted_indexes[s + 1];
      sum -= 2 * loss_derivatives[t];

      if(responses[t] < responses[u] && abs(sum) > max_abs_sum) {
        max_abs_sum = abs(sum);
        _feature_index = f;
        _threshold = (responses[t] + responses[u]) / 2;
      }
    }
  }

  delete[] sorted_indexes;
  delete[] indexes;
  delete[] responses;
}

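// Training procedure (editor's note on the two train() methods below):
// a node first checks the stopping criteria, i.e. the maximum depth
// global.tree_depth_max and the minimum per-class sample counts
// min_nb_samples_for_split. If splitting is allowed, pick_best_split()
// chooses a feature and a threshold, the samples are partitioned, and
// the two subtrees are trained recursively at depth + 1. A node that
// ends up with no children becomes a leaf and receives the
// loss-optimal constant weight, clamped to [-max_weight, max_weight]
// = [-10, 10].
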
void DecisionTree::train(LossMachine *loss_machine,
                         SampleSet *sample_set,
                         scalar_t *current_responses,
                         scalar_t *loss_derivatives,
                         int depth) {

  if(_subtree_lesser || _subtree_greater || _feature_index >= 0) {
    cerr << "You cannot re-train a tree." << endl;
    exit(1);
  }

  int nb_samples = sample_set->nb_samples();

  int nb_pos = 0, nb_neg = 0;

  for(int s = 0; s < sample_set->nb_samples(); s++) {
    if(sample_set->label(s) > 0) nb_pos++;
    else if(sample_set->label(s) < 0) nb_neg++;
  }

  (*global.log_stream) << "Training tree" << endl;
  (*global.log_stream) << "  nb_samples " << nb_samples << endl;
  (*global.log_stream) << "  depth " << depth << endl;
  (*global.log_stream) << "  nb_pos = " << nb_pos << endl;
  (*global.log_stream) << "  nb_neg = " << nb_neg << endl;

  if(depth >= global.tree_depth_max)
    (*global.log_stream) << "  Maximum depth reached." << endl;
  if(nb_pos < min_nb_samples_for_split)
    (*global.log_stream) << "  Not enough positive samples." << endl;
  if(nb_neg < min_nb_samples_for_split)
    (*global.log_stream) << "  Not enough negative samples." << endl;

  if(depth < global.tree_depth_max &&
     nb_pos >= min_nb_samples_for_split &&
     nb_neg >= min_nb_samples_for_split) {

    pick_best_split(sample_set, loss_derivatives);

    if(_feature_index >= 0) {
      int indexes[nb_samples];
      scalar_t *parted_current_responses = new scalar_t[nb_samples];
      scalar_t *parted_loss_derivatives = new scalar_t[nb_samples];

      int nb_lesser = 0, nb_greater = 0;
      int nb_lesser_pos = 0, nb_lesser_neg = 0, nb_greater_pos = 0, nb_greater_neg = 0;

      // Partition the samples: those going to the lesser subtree are
      // packed at the front of the arrays, the others at the back.
      for(int s = 0; s < nb_samples; s++) {
        if(sample_set->feature_value(s, _feature_index) < _threshold) {
          indexes[nb_lesser] = s;
          parted_current_responses[nb_lesser] = current_responses[s];
          parted_loss_derivatives[nb_lesser] = loss_derivatives[s];

          if(sample_set->label(s) > 0)
            nb_lesser_pos++;
          else if(sample_set->label(s) < 0)
            nb_lesser_neg++;

          nb_lesser++;
        } else {
          nb_greater++;

          indexes[nb_samples - nb_greater] = s;
          parted_current_responses[nb_samples - nb_greater] = current_responses[s];
          parted_loss_derivatives[nb_samples - nb_greater] = loss_derivatives[s];

          if(sample_set->label(s) > 0)
            nb_greater_pos++;
          else if(sample_set->label(s) < 0)
            nb_greater_neg++;
        }
      }

      // Recurse only if both sides keep enough samples of at least
      // one class.
      if((nb_lesser_pos >= min_nb_samples_for_split ||
          nb_lesser_neg >= min_nb_samples_for_split) &&
         (nb_greater_pos >= min_nb_samples_for_split ||
          nb_greater_neg >= min_nb_samples_for_split)) {

        _subtree_lesser = new DecisionTree();

        {
          SampleSet sub_sample_set(sample_set, nb_lesser, indexes);
          _subtree_lesser->train(loss_machine,
                                 &sub_sample_set,
                                 parted_current_responses,
                                 parted_loss_derivatives,
                                 depth + 1);
        }

        _subtree_greater = new DecisionTree();

        {
          SampleSet sub_sample_set(sample_set, nb_greater, indexes + nb_lesser);
          _subtree_greater->train(loss_machine,
                                  &sub_sample_set,
                                  parted_current_responses + nb_lesser,
                                  parted_loss_derivatives + nb_lesser,
                                  depth + 1);
        }
      }

      delete[] parted_current_responses;
      delete[] parted_loss_derivatives;
    } else {
      (*global.log_stream) << "Could not find a feature for split." << endl;
    }
  }

  // If no split was made, this node is a leaf: give it the weight
  // minimizing the loss, clamped to a fixed maximum amplitude.
  if(!(_subtree_greater && _subtree_lesser)) {
    scalar_t *tmp_responses = new scalar_t[nb_samples];
    for(int s = 0; s < nb_samples; s++)
      tmp_responses[s] = 1;

    _weight = loss_machine->optimal_weight(sample_set, tmp_responses, current_responses);

    const scalar_t max_weight = 10.0;

    if(_weight > max_weight) {
      _weight = max_weight;
    } else if(_weight < - max_weight) {
      _weight = - max_weight;
    }

    (*global.log_stream) << "  _weight " << _weight << endl;

    delete[] tmp_responses;
  }
}

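// The three-argument train() below is the entry point: it asks the
// LossMachine for the loss derivative of each sample given the current
// responses, grows the tree from depth 0, and frees the scratch array.
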
void DecisionTree::train(LossMachine *loss_machine,
                         SampleSet *sample_set,
                         scalar_t *current_responses) {

  scalar_t *loss_derivatives = new scalar_t[sample_set->nb_samples()];

  loss_machine->get_loss_derivatives(sample_set, current_responses, loss_derivatives);

  train(loss_machine, sample_set, current_responses, loss_derivatives, 0);

  delete[] loss_derivatives;
}

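// Hedged usage sketch (editor's illustration, not code from this
// repository; it assumes a LossMachine and a populated SampleSet have
// been built elsewhere, with current_responses holding one running
// response per sample):
//
//   DecisionTree *tree = new DecisionTree();
//   tree->train(&loss_machine, &sample_set, current_responses);
//   scalar_t r = tree->response(&sample_set, n_sample);
//   delete tree;
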
//////////////////////////////////////////////////////////////////////

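// Editor's note: the two helpers below visit internal nodes only (a
// leaf tests no feature). They presumably support compacting the
// global feature set: tag_used_features marks every feature the tree
// actually tests, and re_index_features rewrites the stored indexes
// after such a remapping.
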
void DecisionTree::tag_used_features(bool *used) {
  if(_subtree_lesser && _subtree_greater) {
    used[_feature_index] = true;
    _subtree_lesser->tag_used_features(used);
    _subtree_greater->tag_used_features(used);
  }
}

void DecisionTree::re_index_features(int *new_indexes) {
  if(_subtree_lesser && _subtree_greater) {
    _feature_index = new_indexes[_feature_index];
    _subtree_lesser->re_index_features(new_indexes);
    _subtree_greater->re_index_features(new_indexes);
  }
}

//////////////////////////////////////////////////////////////////////

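// Serialization format implied by read() and write(): every node
// stores _feature_index, _threshold and _weight, followed by an
// integer flag telling whether the node has subtrees; if it does, the
// lesser subtree is written first, then the greater one, recursively.
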
void DecisionTree::read(istream *is) {
  if(_subtree_lesser || _subtree_greater) {
    cerr << "You cannot read in an existing tree." << endl;
    exit(1);
  }

  read_var(is, &_feature_index);
  read_var(is, &_threshold);
  read_var(is, &_weight);

  int split;
  read_var(is, &split);

  if(split) {
    _subtree_lesser = new DecisionTree();
    _subtree_lesser->read(is);
    _subtree_greater = new DecisionTree();
    _subtree_greater->read(is);
  }
}

void DecisionTree::write(ostream *os) {
  write_var(os, &_feature_index);
  write_var(os, &_threshold);
  write_var(os, &_weight);

  int split;
  if(_subtree_lesser && _subtree_greater) {
    split = 1;
    write_var(os, &split);
    _subtree_lesser->write(os);
    _subtree_greater->write(os);
  } else {
    split = 0;
    write_var(os, &split);
  }
}