Added a function to plot err vs. threshold.
authorFrancois Fleuret <francois@fleuret.org>
Fri, 24 Jul 2009 15:19:11 +0000 (17:19 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Fri, 24 Jul 2009 15:19:11 +0000 (17:19 +0200)
data-tool.cc

index 0a2e049..dc868a0 100644 (file)
@@ -92,6 +92,7 @@ void print_help_and_exit(int e) {
        << "  --help" << endl
        << "  --roc" << endl
        << "  --roc-surface" << endl
+       << "  --error" << endl
        << "  --normalize" << endl
        << "  --histo" << endl
        << "  --cumul" << endl
@@ -121,7 +122,7 @@ int main(int argc, char **argv) {
 
   int i = 1;
 
-  enum { UNKNOWN, ROC, ROC_SURFACE, HISTO, CUMUL, MISC } processing = UNKNOWN;
+  enum { UNKNOWN, ROC, ROC_SURFACE, ERROR, HISTO, CUMUL, MISC } processing = UNKNOWN;
 
   // Parsing the command line arguments ////////////////////////////////
 
@@ -141,6 +142,12 @@ int main(int argc, char **argv) {
       i++;
     }
 
+    else if(strcmp(argv[i], "--error") == 0) {
+      check_single_processing(processing == UNKNOWN);
+      processing = ERROR;
+      i++;
+    }
+
     else if(strcmp(argv[i], "--cumul") == 0) {
       check_single_processing(processing == UNKNOWN);
       processing = CUMUL;
@@ -252,6 +259,7 @@ int main(int argc, char **argv) {
 
   case ROC:
   case ROC_SURFACE:
+  case ERROR:
 
     {
       int nb_samples = 0, nb_samples_max = 1000;
@@ -309,7 +317,7 @@ int main(int argc, char **argv) {
                  << endl;
           }
         }
-      } else {
+      } else if(processing == ROC_SURFACE) {
         double surface = 0;
         double cx = double(nb_fp)/double(nb_rn), cy = 1 - double(nb_fn) / double(nb_rp);
         for(int n = 0; n < nb_samples - 1; n++) {
@@ -322,6 +330,16 @@ int main(int argc, char **argv) {
           }
         }
         cout << surface  << endl;
+      } else {
+        for(int n = 0; n < nb_samples - 1; n++) {
+          if(x[tmp[n].index] >= 0) nb_fn++;
+          else                     nb_fp--;
+          if(tmp[n].value < tmp[n+1].value) {
+            cout << (tmp[n].value + tmp[n+1].value)/2 << " "
+                 << double(nb_fp + nb_fn)/double(nb_rn + nb_rp) << " "
+                 << endl;
+          }
+        }
       }
 
       delete[] x; delete[] y;