X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=loss_machine.h;fp=loss_machine.h;h=a293e8b71a9090bb2ed89fc1d31fafd7c819288d;hb=d922ad61d35e9a6996730bec24b16f8bf7bc426c;hp=0000000000000000000000000000000000000000;hpb=3bb118f5a9462d02ff7d99ef28ecc0d7e23529f9;p=folded-ctf.git diff --git a/loss_machine.h b/loss_machine.h new file mode 100644 index 0000000..a293e8b --- /dev/null +++ b/loss_machine.h @@ -0,0 +1,55 @@ + +/////////////////////////////////////////////////////////////////////////// +// This program is free software: you can redistribute it and/or modify // +// it under the terms of the version 3 of the GNU General Public License // +// as published by the Free Software Foundation. // +// // +// This program is distributed in the hope that it will be useful, but // +// WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU // +// General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program. If not, see . // +// // +// Written by Francois Fleuret, (C) IDIAP // +// Contact for comments & bug reports // +/////////////////////////////////////////////////////////////////////////// + +#ifndef LOSS_MACHINE_H +#define LOSS_MACHINE_H + +#include "misc.h" +#include "sample_set.h" + +class LossMachine { + int _loss_type; + +public: + LossMachine(int loss_type); + + void get_loss_derivatives(SampleSet *samples, + scalar_t *responses, + scalar_t *derivatives); + + scalar_t loss(SampleSet *samples, scalar_t *responses); + + scalar_t optimal_weight(SampleSet *sample_set, + scalar_t *weak_learner_responses, + scalar_t *current_responses); + + // This method returns in sample_nb_occurences[k] the number of time + // the example k was sampled, and in sample_responses[k] the + // consistent response so that the overall loss remains the same. If + // allow_duplicates is set to 1, all samples will have an identical + // response (i.e. weight), but some may have more than one + // occurence. On the contrary, if allow_duplicates is 0, samples + // will all have only one occurence (or zero) but the responses may + // vary to account for the multiple sampling. + + void subsample(int nb, scalar_t *labels, scalar_t *responses, + int nb_to_sample, int *sample_nb_occurences, scalar_t *sample_responses, + int allow_duplicates); +}; + +#endif