X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tools.cc;fp=tools.cc;h=1834cd851e1ab56987db89f0214454b6bd2c8ded;hb=d922ad61d35e9a6996730bec24b16f8bf7bc426c;hp=0000000000000000000000000000000000000000;hpb=3bb118f5a9462d02ff7d99ef28ecc0d7e23529f9;p=folded-ctf.git diff --git a/tools.cc b/tools.cc new file mode 100644 index 0000000..1834cd8 --- /dev/null +++ b/tools.cc @@ -0,0 +1,108 @@ + +/////////////////////////////////////////////////////////////////////////// +// This program is free software: you can redistribute it and/or modify // +// it under the terms of the version 3 of the GNU General Public License // +// as published by the Free Software Foundation. // +// // +// This program is distributed in the hope that it will be useful, but // +// WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU // +// General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program. If not, see . // +// // +// Written by Francois Fleuret, (C) IDIAP // +// Contact for comments & bug reports // +/////////////////////////////////////////////////////////////////////////// + +#include "misc.h" +#include "tools.h" +#include "fusion_sort.h" + +scalar_t robust_sampling(int nb, scalar_t *weights, int nb_to_sample, int *sampled) { + ASSERT(nb > 0); + if(nb == 1) { + for(int k = 0; k < nb_to_sample; k++) sampled[k] = 0; + return weights[0]; + } else { + scalar_t *pair_weights = new scalar_t[(nb+1)/2]; + for(int k = 0; k < nb/2; k++) + pair_weights[k] = weights[2 * k] + weights[2 * k + 1]; + if(nb%2) + pair_weights[(nb+1)/2 - 1] = weights[nb-1]; + scalar_t result = robust_sampling((nb+1)/2, pair_weights, nb_to_sample, sampled); + for(int k = 0; k < nb_to_sample; k++) { + int s = sampled[k]; + // There is a bit of a trick for the isolated sample in the odd + // case. Since the corresponding pair weight is the same as the + // one sample alone, the test is always true and the isolated + // sample will be taken for sure. + if(drand48() * pair_weights[s] <= weights[2 * s]) + sampled[k] = 2 * s; + else + sampled[k] = 2 * s + 1; + } + delete[] pair_weights; + return result; + } +} + +void print_roc_small_pos(ostream *out, + int nb_pos, scalar_t *pos_responses, + int nb_neg, scalar_t *neg_responses, + scalar_t fas_factor) { + + scalar_t *sorted_pos_responses = new scalar_t[nb_pos]; + + fusion_sort(nb_pos, pos_responses, sorted_pos_responses); + + int *bins = new int[nb_pos + 1]; + for(int k = 0; k <= nb_pos; k++) bins[k] = 0; + + for(int k = 0; k < nb_neg; k++) { + scalar_t r = neg_responses[k]; + + if(r < sorted_pos_responses[0]) + bins[0]++; + + else if(r >= sorted_pos_responses[nb_pos - 1]) + bins[nb_pos]++; + + else { + int a = 0; + int b = nb_pos - 1; + int c = 0; + + while(a < b - 1) { + c = (a + b) / 2; + if(r < sorted_pos_responses[c]) + b = c; + else + a = c; + } + + // Beware of identical positive responses + while(c < nb_pos && r >= sorted_pos_responses[c]) + c++; + + bins[c]++; + } + } + + int s = nb_neg; + for(int k = 0; k < nb_pos; k++) { + s -= bins[k]; + if(k == 0 || sorted_pos_responses[k-1] < sorted_pos_responses[k]) { + (*out) << (scalar_t(s) / scalar_t(nb_neg)) * fas_factor + << " " + << scalar_t(nb_pos - k)/scalar_t(nb_pos) + << " " + << sorted_pos_responses[k] + << endl; + } + } + + delete[] bins; + delete[] sorted_pos_responses; +}