X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=clusterer.h;h=ad0c58f910bb22bdb449e8b63d4d5bf4549834c5;hb=30a7eaeed7e34e69b62d2920e074a113a6e850fc;hp=88c168a488511fc1f2d5e41f916d5b04641f7004;hpb=2455f83ba251602d5e04640067094f09f03aaa3d;p=clueless-kmeans.git diff --git a/clusterer.h b/clusterer.h index 88c168a..ad0c58f 100644 --- a/clusterer.h +++ b/clusterer.h @@ -29,35 +29,56 @@ class Clusterer { public: + + enum { + STANDARD_ASSOCIATION, + STANDARD_LP_ASSOCIATION, + UNINFORMATIVE_LP_ASSOCIATION + }; + const static int max_nb_iterations = 10; const static scalar_t min_iteration_improvement = 0.999; + const static scalar_t min_cluster_variance = 0.01f; int _nb_clusters; int _dim; scalar_t **_cluster_means, **_cluster_var; + scalar_t distance_to_centroid(scalar_t *x, int k); + void initialize_clusters(int nb_points, scalar_t **points); + // Standard hard k-mean association + scalar_t baseline_cluster_association(int nb_points, scalar_t **points, int nb_classes, int *labels, scalar_t **gamma); + // Standard k-mean association implemented as an LP optimization + scalar_t baseline_lp_cluster_association(int nb_points, scalar_t **points, int nb_classes, int *labels, scalar_t **gamma); + // Association under the constraint that each cluster gets the same + // class proportions as the overall training set + scalar_t uninformative_lp_cluster_association(int nb_points, scalar_t **points, int nb_classes, int *labels, scalar_t **gamma); - void baseline_update_clusters(int nb_points, scalar_t **points, scalar_t **gamma); + void update_clusters(int nb_points, scalar_t **points, scalar_t **gamma); public: Clusterer(); ~Clusterer(); - void train(int nb_clusters, int dim, + + void train(int mode, + int nb_clusters, int dim, int nb_points, scalar_t **points, int nb_classes, int *labels, + // This last array returns for each sample to what + // cluster it was associated. It can be null. int *cluster_associations); int cluster(scalar_t *point);