ab58f6df1d534b66cc1651be8fd1964fa04e2ee8
[svrt.git] / discrete_density.cc
1 /*
2  *  svrt is the ``Synthetic Visual Reasoning Test'', an image
3  *  generator for evaluating classification performance of machine
4  *  learning systems, humans and primates.
5  *
6  *  Copyright (c) 2009 Idiap Research Institute, http://www.idiap.ch/
7  *  Written by Francois Fleuret <francois.fleuret@idiap.ch>
8  *
9  *  This file is part of svrt.
10  *
11  *  svrt is free software: you can redistribute it and/or modify it
12  *  under the terms of the GNU General Public License version 3 as
13  *  published by the Free Software Foundation.
14  *
15  *  svrt is distributed in the hope that it will be useful, but
16  *  WITHOUT ANY WARRANTY; without even the implied warranty of
17  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  *  General Public License for more details.
19  *
20  *  You should have received a copy of the GNU General Public License
21  *  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
22  *
23  */
24
25 #include "discrete_density.h"
26
27 DiscreteDensityTree::DiscreteDensityTree(scalar_t *proba, int first_value, int nb_values) {
28   if(nb_values > 1) {
29     _proba_tree0 = 0;
30     for(int n = 0; n < nb_values/2; n++) _proba_tree0 += proba[first_value + n];
31     scalar_t s = 0;
32     for(int n = 0; n < nb_values; n++) s += proba[first_value + n];
33     _proba_tree0 /= s;
34     _tree0 = new DiscreteDensityTree(proba, first_value, nb_values/2);
35     _tree1 = new DiscreteDensityTree(proba, first_value + nb_values/2, nb_values - (nb_values/2));
36   } else {
37     _tree0 = 0;
38     _tree1 = 0;
39     _value = first_value;
40   }
41 }
42
43 DiscreteDensityTree::~DiscreteDensityTree() {
44   if(_tree0) delete _tree0;
45   if(_tree1) delete _tree1;
46 }
47
48 int DiscreteDensityTree::sample() {
49   if(_tree0) {
50     if(drand48() < _proba_tree0)
51       return _tree0->sample();
52     else
53       return _tree1->sample();
54   } else return _value;
55 }
56
57 DiscreteDensity::DiscreteDensity(int nb_values) {
58   _nb_values = nb_values;
59   _probas = new scalar_t[_nb_values];
60   _log_probas = new scalar_t[_nb_values];
61   _sampling_tree = 0;
62 }
63
64 DiscreteDensity::~DiscreteDensity() {
65   delete[] _probas;
66   delete[] _log_probas;
67   delete _sampling_tree;
68 }
69
70 void DiscreteDensity::set_non_normalized_proba(int n, scalar_t p) {
71   _probas[n] = p;
72 }
73
74 void DiscreteDensity::normalize() {
75   scalar_t s = 0;
76   for(int k = 0; k < _nb_values; k++) {
77     s += _probas[k];
78   }
79   for(int k = 0; k < _nb_values; k++) {
80     _probas[k] /= s;
81     _log_probas[k] = log(_probas[k]);
82   }
83   delete _sampling_tree;
84   _sampling_tree = new DiscreteDensityTree(_probas, 0, _nb_values);
85 }
86
87 scalar_t DiscreteDensity::entropy() {
88   scalar_t h = 0;
89   for(int k = 0; k < _nb_values; k++) {
90     if(_probas[k] > 0) h += - _probas[k] * _log_probas[k]/log(2.0);
91   }
92   return h;
93 }
94
95 scalar_t DiscreteDensity::proba(int n) {
96   return _probas[n];
97 }
98
99 scalar_t DiscreteDensity::log_proba(int n) {
100   return _log_probas[n];
101 }
102
103 int DiscreteDensity::sample() {
104   return _sampling_tree->sample();
105 }