Update.
[universe.git] / approximer.cc
1
2 // Written and (C) by Francois Fleuret
3 // Contact <francois.fleuret@idiap.ch> for comments & bug reports
4
5 #include "approximer.h"
6
7 MappingApproximer::MappingApproximer(int max_nb_weak_learners) :
8   _max_nb_weak_learners(max_nb_weak_learners),
9   _nb_weak_learners(0),
10   _indexes(new int[_max_nb_weak_learners]),
11   _thresholds(new scalar_t[_max_nb_weak_learners]),
12   _weights(new scalar_t[_max_nb_weak_learners]),
13   _nb_samples(-1),
14   _input_sorted_index(0),
15   _outputs_on_samples(0) { }
16
17 MappingApproximer::~MappingApproximer() {
18   delete[] _indexes;
19   delete[] _thresholds;
20   delete[] _weights;
21   delete[] _outputs_on_samples;
22 }
23
24 void MappingApproximer::load(istream &is) {
25   is.read((char *) &_nb_weak_learners, sizeof(_nb_weak_learners));
26   if(_nb_weak_learners > _max_nb_weak_learners) {
27     cerr << "Number of weak learners missmatch." << endl;
28     exit(1);
29   }
30   is.read((char *) _indexes, _nb_weak_learners * sizeof(int));
31   is.read((char *) _thresholds, _nb_weak_learners * sizeof(scalar_t));
32   is.read((char *) _weights, _nb_weak_learners * sizeof(scalar_t));
33 }
34
35 void MappingApproximer::save(ostream &os) {
36   os.write((char *) &_nb_weak_learners, sizeof(_nb_weak_learners));
37   os.write((char *) _indexes, _nb_weak_learners * sizeof(int));
38   os.write((char *) _thresholds, _nb_weak_learners * sizeof(scalar_t));
39   os.write((char *) _weights, _nb_weak_learners * sizeof(scalar_t));
40 }
41
42 void MappingApproximer::set_learning_input(int input_size, int nb_samples,
43                                            scalar_t *input, scalar_t *sample_weights) {
44
45   _input_size = input_size;
46   _input = input;
47   _sample_weights = sample_weights;
48   if(_nb_samples != nb_samples ){
49     _nb_samples = nb_samples;
50     delete[] _outputs_on_samples;
51     delete[] _input_sorted_index;
52     _outputs_on_samples = new scalar_t[_nb_samples];
53     _input_sorted_index = new int[_input_size * _nb_samples];
54   }
55   for(int t = 0; t < _nb_samples; t++) _outputs_on_samples[t] = 0.0;
56
57   for(int n = 0; n < _input_size; n++) {
58     Couple couples[_nb_samples];
59     for(int s = 0; s < _nb_samples; s++) {
60       couples[s].index = s;
61       couples[s].value = _input[s * _input_size + n];
62     }
63     qsort(couples, _nb_samples, sizeof(Couple), compare_couple);
64     for(int s = 0; s < _nb_samples; s++)
65       _input_sorted_index[n * _nb_samples + s] = couples[s].index;
66   }
67
68 }
69
70 void MappingApproximer::learn_one_step(scalar_t *target) {
71   scalar_t delta[_nb_samples], s_delta = 0.0, s_weights = 0;
72
73   for(int s = 0; s < _nb_samples; s++) {
74     delta[s] = _outputs_on_samples[s] - target[s];
75     s_delta += _sample_weights[s] * delta[s];
76     s_weights += _sample_weights[s];
77   }
78
79   scalar_t best_z = 0, z, prev, val;
80   int *i;
81
82   for(int n = 0; n < _input_size; n++) {
83     z = s_delta;
84     i = _input_sorted_index + n * _nb_samples;
85     prev = _input[(*i) * _input_size + n];
86     for(int s = 1; s < _nb_samples; s++) {
87       z -= 2 * _sample_weights[*i] * delta[*i];
88       i++;
89       val = _input[(*i) * _input_size + n];
90       if(val > prev && abs(z) > abs(best_z)) {
91         _thresholds[_nb_weak_learners] = (val + prev)/2;
92         _indexes[_nb_weak_learners] = n;
93         _weights[_nb_weak_learners] = - z / s_weights;
94         best_z = z;
95       }
96       prev = val;
97     }
98   }
99
100   if(best_z == 0) return;
101
102   // Update the responses on the samples
103   for(int s = 0; s < _nb_samples; s++) {
104     if(_input[s * _input_size + _indexes[_nb_weak_learners]] >= _thresholds[_nb_weak_learners])
105       _outputs_on_samples[s] += _weights[_nb_weak_learners];
106     else
107       _outputs_on_samples[s] -= _weights[_nb_weak_learners];
108   }
109
110   _nb_weak_learners++;
111 }
112
113 scalar_t MappingApproximer::predict(scalar_t *input) {
114   scalar_t r = 0;
115   for(int w = 0; w < _nb_weak_learners; w++)
116     if(input[_indexes[w]] >= _thresholds[w])
117       r += _weights[w];
118     else
119       r -= _weights[w];
120   return r;
121 }
122
123 void test_approximer() {
124 //   const int nb_samples = 1000, nb_weak_learners = 100;
125 //   MappingApproximer approximer(nb_weak_learners);
126 //   scalar_t input[nb_samples], output[nb_samples], weight[nb_samples];
127 //   for(int n = 0; n < nb_samples; n++) {
128 //     input[n] = scalar_t(n * 2 * M_PI)/scalar_t(nb_samples);
129 //     output[n] = sin(input[n]);
130 //     weight[n] = (drand48() < 0.5) ? 1.0 : 0.0;
131 //   }
132 //   approximer.set_learning_input(1, nb_samples, input, weight);
133 //   for(int w = 0; w < nb_weak_learners; w++) {
134 //     approximer.learn_one_step(output);
135 //     scalar_t e = 0;
136 //     for(int n = 0; n < nb_samples; n++)
137 //       e += weight[n] * sq(output[n] - approximer._outputs_on_samples[n]);
138 //     cerr << w << " " << e << endl;
139 //   }
140 //   for(int n = 0; n < nb_samples; n++) {
141 //     cout << input[n] << " " << approximer._outputs_on_samples[n] << endl;
142 //   }
143
144   const int dim = 5, nb_samples = 1000, nb_weak_learners = 100;
145   MappingApproximer approximer(nb_weak_learners);
146   scalar_t input[nb_samples * dim], output[nb_samples], weight[nb_samples];
147   for(int n = 0; n < nb_samples; n++) {
148     scalar_t s = 0;
149     for(int d = 0; d < dim; d++) {
150       input[n * dim + d] = drand48();
151       s += (d+1) * input[n * dim + d];
152     }
153     output[n] = s;
154     weight[n] = (drand48() < 0.5) ? 1.0 : 0.0;
155   }
156   approximer.set_learning_input(dim, nb_samples, input, weight);
157   for(int w = 0; w < nb_weak_learners; w++) {
158     approximer.learn_one_step(output);
159     scalar_t e = 0;
160     for(int n = 0; n < nb_samples; n++)
161       e += weight[n] * sq(output[n] - approximer._outputs_on_samples[n]);
162     cerr << w << " " << e << endl;
163   }
164   for(int n = 0; n < nb_samples; n++) {
165     cout << output[n] << " " << approximer._outputs_on_samples[n] << endl;
166   }
167 }