Added README.md
[clueless-kmeans.git] / clueless-kmeans.cc
1 /*
2  *  clueless-kmeans is a variant of k-means which enforces balanced
3  *  distribution of classes in every cluster
4  *
5  *  Copyright (c) 2013 Idiap Research Institute, http://www.idiap.ch/
6  *  Written by Francois Fleuret <francois.fleuret@idiap.ch>
7  *
8  *  This file is part of clueless-kmeans.
9  *
10  *  clueless-kmeans is free software: you can redistribute it and/or
11  *  modify it under the terms of the GNU General Public License
12  *  version 3 as published by the Free Software Foundation.
13  *
14  *  clueless-kmeans is distributed in the hope that it will be useful,
15  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
16  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  *  General Public License for more details.
18  *
19  *  You should have received a copy of the GNU General Public License
20  *  along with selector.  If not, see <http://www.gnu.org/licenses/>.
21  *
22  */
23
24 #include <iostream>
25 #include <fstream>
26 #include <stdio.h>
27 #include <stdlib.h>
28 #include <float.h>
29 #include <glpk.h>
30
31 using namespace std;
32
33 #include "misc.h"
34 #include "arrays.h"
35 #include "sample_set.h"
36 #include "clusterer.h"
37
38 void generate_toy_problem(SampleSet *sample_set) {
39   int dim = 2;
40   int nb_points = 1000;
41
42   sample_set->resize(dim, nb_points);
43   sample_set->nb_classes = 2;
44
45   for(int n = 0; n < nb_points; n++) {
46     sample_set->labels[n] = int(drand48() * 2);
47     if(sample_set->labels[n] == 0) {
48       sample_set->points[n][0] = (2 * drand48()  - 1) * 0.8;
49       sample_set->points[n][1] = - 0.6 + (2 * drand48()  - 1) * 0.4;
50     } else {
51       sample_set->points[n][0] = (2 * drand48()  - 1) * 0.4;
52       sample_set->points[n][1] =   0.6 + (2 * drand48()  - 1) * 0.4;
53     }
54   }
55 }
56
57 int main(int argc, char **argv) {
58   SampleSet sample_set;
59   Clusterer clusterer;
60   int nb_clusters = 3;
61
62   generate_toy_problem(&sample_set);
63
64   {
65     ofstream out("points.dat");
66     for(int n = 0; n < sample_set.nb_points; n++) {
67       out << sample_set.labels[n];
68       for(int d = 0; d < sample_set.dim; d++) {
69         out << " " << sample_set.points[n][d];
70       }
71       out << endl;
72     }
73   }
74
75   int *associated_clusters = new int[sample_set.nb_points];
76
77   glp_term_out(0);
78
79   int mode;
80
81   if(argc == 2) {
82     if(strcmp(argv[1], "standard") == 0) {
83       mode = Clusterer::STANDARD_LP_ASSOCIATION;
84     } else if(strcmp(argv[1], "clueless") == 0) {
85       mode = Clusterer::UNINFORMATIVE_LP_ASSOCIATION;
86     } else if(strcmp(argv[1], "clueless-absolute") == 0) {
87       mode = Clusterer::UNINFORMATIVE_LP_ASSOCIATION_ABSOLUTE;
88     } else {
89       cerr << "Unknown association mode " << argv[1] << endl;
90       exit(EXIT_FAILURE);
91     }
92   } else {
93     cerr << "Usage: " << argv[0] << " standard|clueless|clueless-absolute" << endl;
94     exit(EXIT_FAILURE);
95   }
96
97   clusterer.train(mode,
98                   nb_clusters,
99                   sample_set.dim,
100                   sample_set.nb_points, sample_set.points,
101                   sample_set.nb_classes, sample_set.labels,
102                   associated_clusters);
103
104   {
105     ofstream out("associated_clusters.dat");
106     for(int n = 0; n < sample_set.nb_points; n++) {
107       out << associated_clusters[n];
108       for(int d = 0; d < sample_set.dim; d++) {
109         out << " " << sample_set.points[n][d];
110       }
111       out << endl;
112     }
113   }
114
115   {
116     ofstream out("clusters.dat");
117     for(int k = 0 ; k < clusterer._nb_clusters; k++) {
118       out << k;
119       for(int d = 0; d < sample_set.dim; d++) {
120         out << " " << clusterer._cluster_means[k][d];
121       }
122       for(int d = 0; d < sample_set.dim; d++) {
123         out << " " << 2 * sqrt(clusterer._cluster_var[k][d]);
124       }
125       out << endl;
126     }
127   }
128
129   delete[] associated_clusters;
130
131   glp_free_env(); // I do not want valgrind to complain
132 }