696477a30f642f08b50802c86081cf79ff9d3497
[folded-ctf.git] / parsing_pool.cc
1
2 ///////////////////////////////////////////////////////////////////////////
3 // This program is free software: you can redistribute it and/or modify  //
4 // it under the terms of the version 3 of the GNU General Public License //
5 // as published by the Free Software Foundation.                         //
6 //                                                                       //
7 // This program is distributed in the hope that it will be useful, but   //
8 // WITHOUT ANY WARRANTY; without even the implied warranty of            //
9 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      //
10 // General Public License for more details.                              //
11 //                                                                       //
12 // You should have received a copy of the GNU General Public License     //
13 // along with this program. If not, see <http://www.gnu.org/licenses/>.  //
14 //                                                                       //
15 // Written by Francois Fleuret, (C) IDIAP                                //
16 // Contact <francois.fleuret@idiap.ch> for comments & bug reports        //
17 ///////////////////////////////////////////////////////////////////////////
18
19 #include "parsing_pool.h"
20 #include "tools.h"
21
22 ParsingPool::ParsingPool(LabelledImagePool *image_pool, PoseCellHierarchy *hierarchy, scalar_t proportion_negative_cells) {
23   _nb_images = image_pool->nb_images();
24   _parsings = new Parsing *[_nb_images];
25
26   _nb_cells = 0;
27   _nb_positive_cells = 0;
28   _nb_negative_cells = 0;
29   for(int i = 0; i < _nb_images; i++) {
30     _parsings[i] = new Parsing(image_pool, hierarchy, proportion_negative_cells, i);
31     _nb_cells += _parsings[i]->nb_cells();
32     _nb_positive_cells += _parsings[i]->nb_positive_cells();
33     _nb_negative_cells += _parsings[i]->nb_negative_cells();
34   }
35   (*global.log_stream) << "ParsingPool initialized" << endl;
36   (*global.log_stream) << "  _nb_cells = " << _nb_cells << endl;
37   (*global.log_stream) << "  _nb_positive_cells = " << _nb_positive_cells << endl;
38   (*global.log_stream) << "  _nb_negative_cells = " << _nb_negative_cells << endl;
39 }
40
41 ParsingPool::~ParsingPool() {
42   for(int i = 0; i < _nb_images; i++)
43     delete _parsings[i];
44   delete[] _parsings;
45 }
46
47 void ParsingPool::down_one_level(LossMachine *loss_machine, PoseCellHierarchy *hierarchy, int level) {
48   scalar_t *labels = new scalar_t[_nb_cells];
49   scalar_t *tmp_responses = new scalar_t[_nb_cells];
50
51   int c;
52
53   { ////////////////////////////////////////////////////////////////////
54     // Sanity check
55     scalar_t l = 0;
56     for(int i = 0; i < _nb_images; i++) {
57       for(int d = 0; d < _parsings[i]->nb_cells(); d++) {
58         if(_parsings[i]->label(d) != 0) {
59           l += exp( - _parsings[i]->label(d) * _parsings[i]->response(d));
60         }
61       }
62     }
63     (*global.log_stream) << "* INITIAL LOSS IS " << l << endl;
64   } ////////////////////////////////////////////////////////////////////
65
66   // Put the negative samples with their current responses, and all
67   // others to 0
68
69   c = 0;
70   for(int i = 0; i < _nb_images; i++) {
71     for(int d = 0; d < _parsings[i]->nb_cells(); d++) {
72       if(_parsings[i]->label(d) < 0) {
73         labels[c] = -1;
74         tmp_responses[c] = _parsings[i]->response(d);
75       } else {
76         labels[c] = 0;
77         tmp_responses[c] = 0;
78       }
79       c++;
80     }
81   }
82
83   // Sub-sample among the negative ones
84
85   int *sample_nb_occurences = new int[_nb_cells];
86   scalar_t *sample_responses = new scalar_t[_nb_cells];
87
88   loss_machine->subsample(_nb_cells, labels, tmp_responses,
89                           _nb_negative_cells, sample_nb_occurences, sample_responses,
90                           1);
91   c = 0;
92   for(int i = 0; i < _nb_images; i++) {
93     for(int d = 0; d < _parsings[i]->nb_cells(); d++) {
94       if(_parsings[i]->label(d) > 0) {
95         sample_nb_occurences[c + d] = 1;
96         sample_responses[c + d] = _parsings[i]->response(d);
97       }
98     }
99
100     int d = c + _parsings[i]->nb_cells();
101
102     _parsings[i]->down_one_level(hierarchy, level, sample_nb_occurences + c, sample_responses + c);
103
104     c = d;
105   }
106
107   { ////////////////////////////////////////////////////////////////////
108     // Sanity check
109     scalar_t l = 0;
110     for(int i = 0; i < _nb_images; i++) {
111       for(int d = 0; d < _parsings[i]->nb_cells(); d++) {
112         if(_parsings[i]->label(d) != 0) {
113           l += exp( - _parsings[i]->label(d) * _parsings[i]->response(d));
114         }
115       }
116     }
117     (*global.log_stream) << "* FINAL LOSS IS " << l << endl;
118   } ////////////////////////////////////////////////////////////////////
119
120   delete[] sample_responses;
121   delete[] sample_nb_occurences;
122 }
123
124 void ParsingPool::update_cell_responses(PiFeatureFamily *pi_feature_family,
125                                         Classifier *classifier) {
126   for(int i = 0; i < _nb_images; i++) {
127     _parsings[i]->update_cell_responses(pi_feature_family, classifier);
128   }
129 }
130
131 void ParsingPool::weighted_sampling(LossMachine *loss_machine,
132                                     PiFeatureFamily *pi_feature_family,
133                                     SampleSet *sample_set,
134                                     scalar_t *responses) {
135
136   int nb_negatives_to_sample = sample_set->nb_samples() - _nb_positive_cells;
137
138   ASSERT(nb_negatives_to_sample > 0);
139
140   scalar_t *labels = new scalar_t[_nb_cells];
141   scalar_t *tmp_responses = new scalar_t[_nb_cells];
142
143   int c, s;
144
145   // Put the negative samples with their current responses, and all
146   // others to 0
147
148   c = 0;
149   for(int i = 0; i < _nb_images; i++) {
150     for(int d = 0; d < _parsings[i]->nb_cells(); d++) {
151       if(_parsings[i]->label(d) < 0) {
152         labels[c] = -1;
153         tmp_responses[c] = _parsings[i]->response(d);
154       } else {
155         labels[c] = 0;
156         tmp_responses[c] = 0;
157       }
158       c++;
159     }
160   }
161
162   // Sub-sample among the negative ones
163
164   int *sample_nb_occurences = new int[_nb_cells];
165   scalar_t *sample_responses = new scalar_t[_nb_cells];
166
167   loss_machine->subsample(_nb_cells, labels, tmp_responses,
168                           nb_negatives_to_sample, sample_nb_occurences, sample_responses,
169                           0);
170
171   for(int k = 0; k < _nb_cells; k++) {
172     if(sample_nb_occurences[k] > 0) {
173       ASSERT(sample_nb_occurences[k] == 1);
174       labels[k] = -1.0;
175       tmp_responses[k] = sample_responses[k];
176     } else {
177       labels[k] = 0;
178     }
179   }
180
181   delete[] sample_responses;
182   delete[] sample_nb_occurences;
183
184   // Put the positive ones
185
186   c = 0;
187   for(int i = 0; i < _nb_images; i++) {
188     for(int d = 0; d < _parsings[i]->nb_cells(); d++) {
189       if(_parsings[i]->label(d) > 0) {
190         labels[c] = 1;
191         tmp_responses[c] = _parsings[i]->response(d);
192       }
193       c++;
194     }
195   }
196
197   // Here we have the responses for the sub-sampled in tmp_responses,
198   // and we have labels[n] set to zero for non-sampled samples
199
200   s = 0;
201   c = 0;
202
203 //   global.bar.init(&cout, _nb_images);
204
205   for(int i = 0; i < _nb_images; i++) {
206
207     int *to_collect = new int[_parsings[i]->nb_cells()];
208
209     for(int d = 0; d < _parsings[i]->nb_cells(); d++) {
210       to_collect[d] = (labels[c + d] != 0);
211     }
212
213     _parsings[i]->collect_samples(sample_set, pi_feature_family, s, to_collect);
214
215     for(int d = 0; d < _parsings[i]->nb_cells(); d++) {
216       if(to_collect[d]) {
217         responses[s++] = tmp_responses[c + d];
218       }
219     }
220
221     delete[] to_collect;
222
223     c += _parsings[i]->nb_cells();
224
225 //     global.bar.refresh(&cout, i);
226   }
227
228 //   global.bar.finish(&cout);
229
230   delete[] tmp_responses;
231   delete[] labels;
232 }
233
234 void ParsingPool::write_roc(ofstream *out) {
235   int nb_negatives = nb_negative_cells();
236   int nb_positives = nb_positive_cells();
237
238   scalar_t *pos_responses = new scalar_t[nb_positives];
239   scalar_t *neg_responses = new scalar_t[nb_negatives];
240   int np = 0, nn = 0;
241   for(int i = 0; i < _nb_images; i++) {
242     for(int c = 0; c < _parsings[i]->nb_cells(); c++) {
243       if(_parsings[i]->label(c) > 0)
244         pos_responses[np++] = _parsings[i]->response(c);
245       else if(_parsings[i]->label(c) < 0)
246         neg_responses[nn++] = _parsings[i]->response(c);
247     }
248   }
249
250   ASSERT(nn == nb_negatives && np == nb_positives);
251
252   print_roc_small_pos(out,
253                       nb_positives, pos_responses,
254                       nb_negatives, neg_responses,
255                       1.0);
256
257   delete[] pos_responses;
258   delete[] neg_responses;
259 }