Added a function to plot err vs. threshold.
[data-tool.git] / data-tool.cc
1
2 /*
3  *  data-tool is a command line tool to do simple statistical
4  *  processing on numerical data.
5  *
6  *  Copyright (c) 2009 Francois Fleuret
7  *  Written by Francois Fleuret <francois@fleuret.org>
8  *
9  *  This file is part of data-tool.
10  *
11  *  data-tool is free software: you can redistribute it and/or modify
12  *  it under the terms of the GNU General Public License version 3 as
13  *  published by the Free Software Foundation.
14  *
15  *  data-tool is distributed in the hope that it will be useful, but
16  *  WITHOUT ANY WARRANTY; without even the implied warranty of
17  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  *  General Public License for more details.
19  *
20  *  You should have received a copy of the GNU General Public License
21  *  along with data-tool.  If not, see <http://www.gnu.org/licenses/>.
22  *
23  */
24
25 #include <iostream>
26 #include <cmath>
27 #include <stdlib.h>
28 #include <string.h>
29
30 using namespace std;
31
32 struct Couple {
33   int index;
34   double value;
35 };
36
37 int compare_couple(const void *a, const void *b) {
38   if(((Couple *) a)->value < ((Couple *) b)->value) return -1;
39   else if(((Couple *) a)->value > ((Couple *) b)->value) return 1;
40   else return 0;
41 }
42
43 double *inflate_array(double *x, int current_size, int new_size) {
44   double *xx = new double[new_size];
45   for(int n = 0; n < current_size; n++) xx[n] = x[n];
46   delete[] x;
47   return xx;
48 }
49
50 char *next_word(char *buffer, char *r, int buffer_size) {
51   char *s;
52   s = buffer;
53   if(r != NULL)
54     {
55       if(*r == '"') {
56         r++;
57         while((*r != '"') && (*r != '\0') &&
58               (s<buffer+buffer_size-1))
59           *s++ = *r++;
60         if(*r == '"') r++;
61       } else {
62         while((*r != '\r') && (*r != '\n') && (*r != '\0') &&
63               (*r != '\t') && (*r != ' ') && (*r != ',') &&
64               (s<buffer+buffer_size-1))
65           *s++ = *r++;
66       }
67
68       while((*r == ' ') || (*r == '\t') || (*r == ',')) r++;
69       if((*r == '\0') || (*r=='\r') || (*r=='\n')) r = NULL;
70     }
71   *s = '\0';
72   return r;
73 }
74
75 void check_opt(int argc, char **argv, int n_opt, int n, const char *help) {
76   if(n_opt + n >= argc) {
77     cerr << "ERROR: Missing argument for " << argv[n_opt] << ". Expecting " << help << "." << endl;
78     exit(1);
79   }
80 }
81
82 void print_help_and_exit(int e) {
83   cout << "Simple data processing tool. Written by Francois Fleuret." << endl
84        << endl
85        << "This application takes data from the standard input and prints" << endl
86        << "the result on the standard output. It expects either a list of" << endl
87        << "float values (to produce histograms, cumulative distribution functions" << endl
88        << "or the mean, variance, etc.) or a list of couples of values of the form" << endl
89        << "x y on each line (where the sign of x tells the class and y the parameter" << endl
90        << "value) to compute the ROC curve or the ROC curve surface.\n" << endl
91        << "The options are:" << endl
92        << "  --help" << endl
93        << "  --roc" << endl
94        << "  --roc-surface" << endl
95        << "  --error" << endl
96        << "  --normalize" << endl
97        << "  --histo" << endl
98        << "  --cumul" << endl
99        << "  --misc" << endl
100        << "  --auto-extrema" << endl
101        << "  --xbounds <float: xmin> <float: xmax>" << endl
102        << "  --ybounds <float: ymin> <float: ymax>" << endl
103        << "  --nb-bins <int: number of bins>" << endl;
104   exit(e);
105 }
106
107 void check_single_processing(bool unknown_processing) {
108   if(!unknown_processing) {
109     cerr << "ERROR: You can't do two different processings." << endl;
110     exit(1);
111   }
112 }
113
114 int main(int argc, char **argv) {
115   double xmin = 0, xmax = 1, ymin = 0, ymax = 1;
116   int nb_bins = 10;
117   const int buffer_size = 1024;
118
119   char line[buffer_size], token[buffer_size];
120   bool auto_extrema = false;
121   bool normalize = false;
122
123   int i = 1;
124
125   enum { UNKNOWN, ROC, ROC_SURFACE, ERROR, HISTO, CUMUL, MISC } processing = UNKNOWN;
126
127   // Parsing the command line arguments ////////////////////////////////
128
129   while(i < argc) {
130
131     if(argc == 1 || strcmp(argv[i], "--help") == 0) print_help_and_exit(0);
132
133     else if(strcmp(argv[i], "--roc") == 0) {
134       check_single_processing(processing == UNKNOWN);
135       processing = ROC;
136       i++;
137     }
138
139     else if(strcmp(argv[i], "--roc-surface") == 0) {
140       check_single_processing(processing == UNKNOWN);
141       processing = ROC_SURFACE;
142       i++;
143     }
144
145     else if(strcmp(argv[i], "--error") == 0) {
146       check_single_processing(processing == UNKNOWN);
147       processing = ERROR;
148       i++;
149     }
150
151     else if(strcmp(argv[i], "--cumul") == 0) {
152       check_single_processing(processing == UNKNOWN);
153       processing = CUMUL;
154       i++;
155     }
156
157     else if(strcmp(argv[i], "--normalize") == 0) {
158       normalize = true;
159       i++;
160     }
161
162     else if(strcmp(argv[i], "--histo") == 0) {
163       check_single_processing(processing == UNKNOWN);
164       processing = HISTO;
165       i++;
166     }
167
168     else if(strcmp(argv[i], "--misc") == 0) {
169       check_single_processing(processing == UNKNOWN);
170       processing = MISC;
171       i++;
172     }
173
174     else if(strcmp(argv[i], "--auto-extrema") == 0) {
175       auto_extrema = true;
176       i++;
177     }
178
179     else if(strcmp(argv[i], "--xbounds") == 0) {
180       check_opt(argc, argv, i, 2, "<float: xmin> <float: xmax>");
181       xmin = atof(argv[i+1]);
182       xmax = atof(argv[i+2]);
183       if(xmin >= xmax) {
184         cerr << "ERROR: Incorrect bounds." << endl;
185         exit(1);
186       }
187       i += 3;
188     }
189
190     else if(strcmp(argv[i], "--ybounds") == 0) {
191       check_opt(argc, argv, i, 2, "<float: ymin> <float: ymax>");
192       ymin = atof(argv[i+1]);
193       ymax = atof(argv[i+2]);
194       if(ymin >= ymax) {
195         cerr << "ERROR: Incorrect bounds." << endl;
196         exit(1);
197       }
198       i += 3;
199     }
200
201     else if(strcmp(argv[i], "--nb-bins") == 0) {
202       check_opt(argc, argv, i, 1, "<int: number of bins>");
203       nb_bins = atoi(argv[i+1]);
204       if(nb_bins < 1) {
205         cerr << "ERROR: Incorrect number of bins." << endl;
206         exit(1);
207       }
208       i += 2;
209     }
210
211     else {
212       cerr << "ERROR: Unknown option " << argv[i]  << endl;
213       print_help_and_exit(1);
214     }
215   }
216
217   // Processing the data ///////////////////////////////////////////////
218
219   switch(processing) {
220
221   case CUMUL:
222
223     {
224       int nb_samples = 0, nb_samples_max = 50000;
225       double *x = new double[nb_samples_max];
226
227       while(!cin.eof()) {
228         if(nb_samples == nb_samples_max) {
229           x = inflate_array(x, nb_samples_max, 2 * nb_samples_max);
230           nb_samples_max = 2 * nb_samples_max;
231         }
232
233         cin.getline(line, buffer_size);
234
235         if(line[0]) {
236           char *s = line;
237           s = next_word(token, s, buffer_size);
238           x[nb_samples] = atof(token);
239           nb_samples++;
240         }
241       }
242
243       Couple tmp[nb_samples];
244       for(int n = 0; n < nb_samples; n++) {
245         tmp[n].index = n;
246         tmp[n].value = x[n];
247       }
248
249       qsort(tmp, nb_samples, sizeof(Couple), compare_couple);
250
251       for(int n = 0; n < nb_samples; n++)
252         cout << tmp[n].value << " " << double(n)/double(nb_samples)  << endl;
253
254       delete[] x;
255
256     }
257
258     break;
259
260   case ROC:
261   case ROC_SURFACE:
262   case ERROR:
263
264     {
265       int nb_samples = 0, nb_samples_max = 1000;
266       double *x = new double[nb_samples_max], *y = new double[nb_samples_max];
267
268       while(!cin.eof()) {
269         if(nb_samples == nb_samples_max) {
270           x = inflate_array(x, nb_samples_max, 2 * nb_samples_max);
271           y = inflate_array(y, nb_samples_max, 2 * nb_samples_max);
272           nb_samples_max = 2 * nb_samples_max;
273         }
274
275         cin.getline(line, buffer_size);
276
277         if(line[0]) {
278           char *s = line;
279           s = next_word(token, s, buffer_size);
280           x[nb_samples] = atof(token);
281           s = next_word(token, s, buffer_size);
282           y[nb_samples] = atof(token);
283           nb_samples++;
284         }
285       }
286
287       Couple tmp[nb_samples];
288       int nb_rn = 0, nb_rp = 0, nb_fp = 0, nb_fn = 0;
289
290       bool binary = true;
291       for(int n = 0; binary && n < nb_samples; n++) binary &= (x[n] == 0 || x[n] == 1);
292       if(binary) {
293         cerr << "WARNING: your classes are binary, I process them accordingly." << endl;
294         for(int n = 0; n < nb_samples; n++) x[n] = 2 * x[n] - 1;
295       }
296
297       for(int n = 0; n < nb_samples; n++) {
298         tmp[n].index = n;
299         tmp[n].value = y[n];
300         if(x[n] >= 0) nb_rp++;
301         else { nb_rn++; nb_fp++; }
302       }
303
304       if(nb_rp == 0) cerr << "WARNING: No true positive." << endl;
305       if(nb_rn == 0) cerr << "WARNING: No true negative." << endl;
306
307       qsort(tmp, nb_samples, sizeof(Couple), compare_couple);
308
309       if(processing == ROC) {
310         for(int n = 0; n < nb_samples - 1; n++) {
311           if(x[tmp[n].index] >= 0) nb_fn++;
312           else                     nb_fp--;
313           if(tmp[n].value < tmp[n+1].value) {
314             cout << double(nb_fp)/double(nb_rn) << " "
315                  << 1 - double(nb_fn) / double(nb_rp) << " "
316                  << (tmp[n].value + tmp[n+1].value)/2 << " "
317                  << endl;
318           }
319         }
320       } else if(processing == ROC_SURFACE) {
321         double surface = 0;
322         double cx = double(nb_fp)/double(nb_rn), cy = 1 - double(nb_fn) / double(nb_rp);
323         for(int n = 0; n < nb_samples - 1; n++) {
324           if(x[tmp[n].index] >= 0) nb_fn++;
325           else                     nb_fp--;
326           if(tmp[n].value < tmp[n+1].value) {
327             double ncx = double(nb_fp)/double(nb_rn), ncy = 1 - double(nb_fn) / double(nb_rp);
328             surface += (cx - ncx) * cy;
329             cx = ncx; cy = ncy;
330           }
331         }
332         cout << surface  << endl;
333       } else {
334         for(int n = 0; n < nb_samples - 1; n++) {
335           if(x[tmp[n].index] >= 0) nb_fn++;
336           else                     nb_fp--;
337           if(tmp[n].value < tmp[n+1].value) {
338             cout << (tmp[n].value + tmp[n+1].value)/2 << " "
339                  << double(nb_fp + nb_fn)/double(nb_rn + nb_rp) << " "
340                  << endl;
341           }
342         }
343       }
344
345       delete[] x; delete[] y;
346
347     }
348
349     break;
350
351   case HISTO:
352
353     {
354       int nb_samples = 0, nb_samples_max = 1000;
355       double *x = new double[nb_samples_max];
356
357       while(!cin.eof()) {
358         if(nb_samples == nb_samples_max) {
359           x = inflate_array(x, nb_samples_max, 2 * nb_samples_max);
360           nb_samples_max = 2 * nb_samples_max;
361         }
362
363         cin.getline(line, buffer_size);
364
365         if(line[0]) {
366           char *s = line;
367           s = next_word(token, s, buffer_size);
368           x[nb_samples] = atof(token);
369           if(auto_extrema) {
370             if(nb_samples == 0 || x[nb_samples] > xmax) xmax = x[nb_samples];
371             if(nb_samples == 0 || x[nb_samples] < xmin) xmin = x[nb_samples];
372           }
373           nb_samples++;
374         }
375       }
376
377       int nb[nb_bins];
378       for(int n = 0; n < nb_bins; n++) nb[n] = 0;
379
380       int nb_total = 0;
381       for(int s = 0; s < nb_samples; s++) {
382         int n = int((x[s] - xmin)/(xmax - xmin) * nb_bins);
383         if(n >= 0 && n < nb_bins) nb[n]++;
384         else {
385           cerr << "WARNING: value " << x[s] << " is out of histogram." << endl;
386         }
387         nb_total++;
388       }
389
390       if(normalize) {
391         for(int n = 0; n < nb_bins; n++)
392           cout << xmin + ((xmax - xmin) * n) / double(nb_bins) << " "
393                << (nb[n] / double(nb_total))/((xmax - xmin) / double(nb_bins))  << endl;
394       } else {
395         for(int n = 0; n < nb_bins; n++)
396           cout << xmin + ((xmax - xmin) * n) / double(nb_bins) << " "
397                << nb[n] / double(nb_total)  << endl;
398       }
399     }
400
401     break;
402
403   case MISC:
404
405     {
406       int nb_samples = 0, nb_samples_max = 1000;
407       double *x = new double[nb_samples_max];
408       int nb = 0;
409       double min = 0, max = 0;
410       double sum = 0, sumsq = 0;
411
412       while(!cin.eof()) {
413         if(nb_samples == nb_samples_max) {
414           x = inflate_array(x, nb_samples_max, 2 * nb_samples_max);
415           nb_samples_max = 2 * nb_samples_max;
416         }
417
418         cin.getline(line, buffer_size);
419         char *s = line;
420         if(line[0]) {
421           s = next_word(token, s, buffer_size);
422           x[nb_samples] = atof(token);
423           nb_samples++;
424           double x = atof(token);
425           if(nb == 0 || x > max) max = x;
426           if(nb == 0 || x < min) min = x;
427           sum += x;
428           sumsq += x*x;
429           nb++;
430         }
431       }
432
433       Couple tmp[nb_samples];
434       for(int n = 0; n < nb_samples; n++) {
435         tmp[n].index = n;
436         tmp[n].value = x[n];
437       }
438
439       qsort(tmp, nb_samples, sizeof(Couple), compare_couple);
440
441       delete[] x;
442
443       double mu = sum / double(nb);
444       double sigma = (sumsq - sum * mu) / double(nb - 1);
445       double stdd = sqrt(sigma);
446
447       cout << "MIN " << min
448            << " MAX " << max
449            << " MU " << mu
450            << " SIGMA " << sigma
451            << " STDD " << stdd
452            << " SUM " << sum
453            << " MEDIAN " << tmp[nb_samples/2].value
454            << " QUANTILE0.1 " << tmp[int(nb_samples * 0.1)].value
455            << " QUANTILE0.9 " << tmp[int(nb_samples * 0.9)].value
456            << endl;
457
458     }
459
460     break;
461
462   default:
463     cerr << "ERROR: You must choose a processing type." << endl;
464     exit(1);
465   }
466
467 }