1c2b08d219cc39c5ab378af4ed9b9e7967528833
[svrt.git] / discrete_density.h
1 /*
2  *  svrt is the ``Synthetic Visual Reasoning Test'', an image
3  *  generator for evaluating classification performance of machine
4  *  learning systems, humans and primates.
5  *
6  *  Copyright (c) 2009 Idiap Research Institute, http://www.idiap.ch/
7  *  Written by Francois Fleuret <francois.fleuret@idiap.ch>
8  *
9  *  This file is part of svrt.
10  *
11  *  svrt is free software: you can redistribute it and/or modify it
12  *  under the terms of the GNU General Public License version 3 as
13  *  published by the Free Software Foundation.
14  *
15  *  svrt is distributed in the hope that it will be useful, but
16  *  WITHOUT ANY WARRANTY; without even the implied warranty of
17  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  *  General Public License for more details.
19  *
20  *  You should have received a copy of the GNU General Public License
21  *  along with selector.  If not, see <http://www.gnu.org/licenses/>.
22  *
23  */
24
25 #ifndef DISCRETE_DENSITY_H
26 #define DISCRETE_DENSITY_H
27
28 #include "misc.h"
29
30 class DiscreteDensityTree {
31   scalar_t _proba_tree0;
32   DiscreteDensityTree *_tree0, *_tree1;
33   int _value;
34 public:
35   DiscreteDensityTree(scalar_t *proba, int first_value, int nb_values);
36   ~DiscreteDensityTree();
37   int sample();
38 };
39
40 class DiscreteDensity {
41   DiscreteDensityTree *_sampling_tree;
42   int _nb_values;
43   scalar_t *_probas, *_log_probas;
44 public:
45   DiscreteDensity(int nb_values);
46   ~DiscreteDensity();
47
48   void set_non_normalized_proba(int n, scalar_t p);
49   void normalize();
50   scalar_t entropy();
51
52   scalar_t proba(int n);
53   scalar_t log_proba(int n);
54   int sample();
55 };
56
57 #endif