X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=folded-ctf.git;a=blobdiff_plain;f=loss_machine.cc;fp=loss_machine.cc;h=63a65cc208f67348358f492adb2b1e90cb745498;hp=c596de65bb850654297dd0849a0e62e81fbecfb5;hb=5b78a555f6c7ff20a71d0520db63bc43e69e1f41;hpb=a44d56305517dddd36027c2a84693f2785fd019a diff --git a/loss_machine.cc b/loss_machine.cc index c596de6..63a65cc 100644 --- a/loss_machine.cc +++ b/loss_machine.cc @@ -40,47 +40,6 @@ void LossMachine::get_loss_derivatives(SampleSet *samples, } break; - case LOSS_EV_REGULARIZED: - { - scalar_t sum_pos = 0, sum_sq_pos = 0, nb_pos = 0, m_pos, v_pos; - scalar_t sum_neg = 0, sum_sq_neg = 0, nb_neg = 0, m_neg, v_neg; - - for(int n = 0; n < samples->nb_samples(); n++) { - if(samples->label(n) > 0) { - sum_pos += responses[n]; - sum_sq_pos += sq(responses[n]); - nb_pos += 1.0; - } - else if(samples->label(n) < 0) { - sum_neg += responses[n]; - sum_sq_neg += sq(responses[n]); - nb_neg += 1.0; - } - } - - m_pos = sum_pos / nb_pos; - v_pos = sum_sq_pos/(nb_pos - 1) - sq(sum_pos)/(nb_pos * (nb_pos - 1)); - - scalar_t loss_pos = nb_pos * exp(v_pos/2 - m_pos); - - m_neg = sum_neg / nb_neg; - v_neg = sum_sq_neg/(nb_neg - 1) - sq(sum_neg)/(nb_neg * (nb_neg - 1)); - - scalar_t loss_neg = nb_neg * exp(v_neg/2 + m_neg); - - for(int n = 0; n < samples->nb_samples(); n++) { - if(samples->label(n) > 0) { - derivatives[n] = - ( - 1/nb_pos + (responses[n] - m_pos)/(nb_pos - 1)) * loss_pos; - } else if(samples->label(n) < 0) { - derivatives[n] = - ( 1/nb_neg + (responses[n] - m_neg)/(nb_neg - 1)) * loss_neg; - } - } - } - - break; - case LOSS_HINGE: { for(int n = 0; n < samples->nb_samples(); n++) {