Added REAME.md
[svrt.git] / error_rates.cc
1 /*
2  *  svrt is the ``Synthetic Visual Reasoning Test'', an image
3  *  generator for evaluating classification performance of machine
4  *  learning systems, humans and primates.
5  *
6  *  Copyright (c) 2009 Idiap Research Institute, http://www.idiap.ch/
7  *  Written by Francois Fleuret <francois.fleuret@idiap.ch>
8  *
9  *  This file is part of svrt.
10  *
11  *  svrt is free software: you can redistribute it and/or modify it
12  *  under the terms of the GNU General Public License version 3 as
13  *  published by the Free Software Foundation.
14  *
15  *  svrt is distributed in the hope that it will be useful, but
16  *  WITHOUT ANY WARRANTY; without even the implied warranty of
17  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  *  General Public License for more details.
19  *
20  *  You should have received a copy of the GNU General Public License
21  *  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
22  *
23  */
24
25 #include "error_rates.h"
26 #include "rgb_image.h"
27
28 scalar_t error_rate(Classifier *classifier, int nb_vignettes, Vignette *vignettes, int *labels) {
29   int e = 0;
30
31   for(int n = 0; n < nb_vignettes; n++) {
32     if(classifier->classify(&vignettes[n]) >= 0) {
33       if(labels[n] == 0) e++;
34     } else {
35       if(labels[n] == 1) e++;
36     }
37   }
38
39   return scalar_t(e)/scalar_t(nb_vignettes);
40 }
41
42 scalar_t test_error_rate(VignetteGenerator *generator, Classifier *classifier, long int nb_to_try) {
43   scalar_t e = 0;
44   Vignette vignette;
45   int label;
46
47   global.bar.init(&cout, nb_to_try);
48
49   for(long int k = 0; k < nb_to_try; k++) {
50     label = int(drand48() * 2);
51     generator->generate(label, &vignette);
52     if(classifier->classify(&vignette) >= 0) {
53       if(label == 0) e++;
54     } else {
55       if(label == 1) e++;
56     }
57     global.bar.refresh(&cout, k);
58   }
59   global.bar.finish(&cout);
60
61   return scalar_t(e)/scalar_t(nb_to_try);
62 }
63
64
65 void compute_response_mu_and_sigma(int nb_samples, Vignette *vignette, Classifier *classifier,
66                                    scalar_t *mu, scalar_t *sigma) {
67   scalar_t sum = 0, sum_sq = 0;
68   const int nb_pixels_to_switch = 1;
69   int changed_pixels[nb_pixels_to_switch];
70
71   for(int n = 0; n < nb_samples; n++) {
72     for(int p = 0; p < nb_pixels_to_switch; p++) {
73       changed_pixels[p] = int(drand48() * Vignette::width * Vignette::height);
74       vignette->content[changed_pixels[p]] = 255 - vignette->content[changed_pixels[p]];
75     }
76
77     scalar_t r = classifier->classify(vignette);
78     sum += r;
79     sum_sq += sq(r);
80
81     for(int p = 0; p < nb_pixels_to_switch; p++) {
82       vignette->content[changed_pixels[p]] = 255 - vignette->content[changed_pixels[p]];
83     }
84   }
85
86   *mu = sum_sq/scalar_t(nb_samples);
87   *sigma = *mu - sq(sum/scalar_t(nb_samples));
88 }