X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=parsing.cc;fp=parsing.cc;h=9d0060d06ca08c2a3ca786db2584c1452a6719c6;hb=d922ad61d35e9a6996730bec24b16f8bf7bc426c;hp=0000000000000000000000000000000000000000;hpb=3bb118f5a9462d02ff7d99ef28ecc0d7e23529f9;p=folded-ctf.git diff --git a/parsing.cc b/parsing.cc new file mode 100644 index 0000000..9d0060d --- /dev/null +++ b/parsing.cc @@ -0,0 +1,186 @@ + +/////////////////////////////////////////////////////////////////////////// +// This program is free software: you can redistribute it and/or modify // +// it under the terms of the version 3 of the GNU General Public License // +// as published by the Free Software Foundation. // +// // +// This program is distributed in the hope that it will be useful, but // +// WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU // +// General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program. If not, see . // +// // +// Written by Francois Fleuret, (C) IDIAP // +// Contact for comments & bug reports // +/////////////////////////////////////////////////////////////////////////// + +#include "parsing.h" +#include "fusion_sort.h" + +Parsing::Parsing(LabelledImagePool *image_pool, + PoseCellHierarchy *hierarchy, + scalar_t proportion_negative_cells, + int image_index) { + + _image_pool = image_pool; + _image_index = image_index; + + PoseCellSet cell_set; + LabelledImage *image; + + image = _image_pool->grab_image(_image_index); + + hierarchy->add_root_cells(image, &cell_set); + + int *kept = new int[cell_set.nb_cells()]; + + _nb_cells = 0; + + for(int c = 0; c < cell_set.nb_cells(); c++) { + int l = image->pose_cell_label(cell_set.get_cell(c)); + kept[c] = (l > 0) || (l < 0 && drand48() < proportion_negative_cells); + if(kept[c]) _nb_cells++; + } + + _cells = new PoseCell[_nb_cells]; + _responses = new scalar_t[_nb_cells]; + _labels = new int[_nb_cells]; + _nb_positives = 0; + _nb_negatives = 0; + + int d = 0; + for(int c = 0; c < cell_set.nb_cells(); c++) { + if(kept[c]) { + _cells[d] = *(cell_set.get_cell(c)); + _labels[d] = image->pose_cell_label(&_cells[d]); + _responses[d] = 0; + if(_labels[d] < 0) { + _nb_negatives++; + } else if(_labels[d] > 0) { + _nb_positives++; + } + d++; + } + } + + delete[] kept; + + _image_pool->release_image(_image_index); +} + +Parsing::~Parsing() { + delete[] _cells; + delete[] _responses; + delete[] _labels; +} + +void Parsing::down_one_level(PoseCellHierarchy *hierarchy, + int level, int *sample_nb_occurences, scalar_t *sample_responses) { + PoseCellSet cell_set; + LabelledImage *image; + + int new_nb_cells = 0; + for(int c = 0; c < _nb_cells; c++) { + new_nb_cells += sample_nb_occurences[c]; + } + + PoseCell *new_cells = new PoseCell[new_nb_cells]; + scalar_t *new_responses = new scalar_t[new_nb_cells]; + int *new_labels = new int[new_nb_cells]; + + image = _image_pool->grab_image(_image_index); + int b = 0; + + for(int c = 0; c < _nb_cells; c++) { + + if(sample_nb_occurences[c] > 0) { + + cell_set.erase_content(); + hierarchy->add_subcells(level, _cells + c, &cell_set); + + if(_labels[c] > 0) { + ASSERT(sample_nb_occurences[c] == 1); + int e = -1; + for(int d = 0; d < cell_set.nb_cells(); d++) { + if(image->pose_cell_label(cell_set.get_cell(d)) > 0) { + ASSERT(e < 0); + e = d; + } + } + ASSERT(e >= 0); + ASSERT(b < new_nb_cells); + new_cells[b] = *(cell_set.get_cell(e)); + new_responses[b] = sample_responses[c]; + new_labels[b] = 1; + b++; + } + + else if(_labels[c] < 0) { + for(int d = 0; d < sample_nb_occurences[c]; d++) { + ASSERT(b < new_nb_cells); + new_cells[b] = *(cell_set.get_cell(int(drand48() * cell_set.nb_cells()))); + new_responses[b] = sample_responses[c]; + new_labels[b] = -1; + b++; + } + } + + else { + cerr << "INCONSISTENCY" << endl; + abort(); + } + } + } + + ASSERT(b == new_nb_cells); + + _image_pool->release_image(_image_index); + + delete[] _cells; + delete[] _labels; + delete[] _responses; + _nb_cells = new_nb_cells; + _cells = new_cells; + _labels = new_labels; + _responses = new_responses; +} + +void Parsing::update_cell_responses(PiFeatureFamily *pi_feature_family, + Classifier *classifier) { + LabelledImage *image; + + image = _image_pool->grab_image(_image_index); + image->compute_rich_structure(); + + SampleSet *samples = new SampleSet(pi_feature_family->nb_features(), 1); + + for(int c = 0; c < _nb_cells; c++) { + samples->set_sample(0, pi_feature_family, image, &_cells[c], 0); + _responses[c] += classifier->response(samples, 0); + ASSERT(!isnan(_responses[c])); + } + + _image_pool->release_image(_image_index); + delete samples; +} + +void Parsing::collect_samples(SampleSet *samples, + PiFeatureFamily *pi_feature_family, + int s, + int *to_collect) { + LabelledImage *image; + + image = _image_pool->grab_image(_image_index); + image->compute_rich_structure(); + + for(int c = 0; c < _nb_cells; c++) { + if(to_collect[c]) { + samples->set_sample(s, pi_feature_family, image, &_cells[c], _labels[c]); + s++; + } + } + + _image_pool->release_image(_image_index); +}