2 ////////////////////////////////////////////////////////////////////////////////
3 // This program is free software; you can redistribute it and/or //
4 // modify it under the terms of the GNU General Public License //
5 // version 2 as published by the Free Software Foundation. //
7 // This program is distributed in the hope that it will be useful, but //
8 // WITHOUT ANY WARRANTY; without even the implied warranty of //
9 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU //
10 // General Public License for more details. //
12 // Written and (C) by François Fleuret //
13 // Contact <francois.fleuret@epfl.ch> for comments & bug reports //
14 ////////////////////////////////////////////////////////////////////////////////
16 #include "intelligence.h"
// Intelligence constructor.
// Allocates the replay memory later filled by update() and consumed by
// learn(). For each of the _max_memory_tick ticks it reserves:
//   - one snapshot of the input map (_input->nb_parameters scalars),
//   - one reward,
//   - one action index.
// It also creates one MappingApproximer per action (manipulator->nb_actions()
// of them), each constructed with _nb_weak_learners.
// NOTE(review): this listing omits short source lines. The max_memory_tick
// parameter and the initialization of _input and _memory_tick (both read by
// other methods) are not visible here -- confirm against the full file.
// NOTE(review): the _memory size expression reads _max_memory_tick and
// _input. Members are initialized in declaration order, not in the order of
// this list, so both must be declared before _memory in the class. Verify
// in intelligence.h.
18 Intelligence::Intelligence(Map *input,
19 Manipulator *manipulator,
21 int nb_weak_learners) :
22 _nb_actions(manipulator->nb_actions()),
24 _manipulator(manipulator),
25 _max_memory_tick(max_memory_tick),
// Flat buffer: tick t occupies [t * nb_parameters, (t + 1) * nb_parameters),
// which is the layout update() writes.
27 _memory(new scalar_t[_max_memory_tick * _input->nb_parameters]),
28 _rewards(new scalar_t[_max_memory_tick]),
29 _actions(new int[_max_memory_tick]),
30 _q_predictors(new MappingApproximer *[manipulator->nb_actions()]),
31 _nb_weak_learners(nb_weak_learners) {
// One Q-value approximer per action.
32 for(int a = 0; a < _nb_actions; a++)
33 _q_predictors[a] = new MappingApproximer(_nb_weak_learners);
// Intelligence destructor: deletes every per-action predictor, then the
// array that holds them.
// NOTE(review): this listing shows no delete[] of _memory, _rewards or
// _actions, which the constructor allocates with new[]. If the full source
// also lacks them, those buffers leak -- confirm.
// NOTE(review): the class owns raw new[] buffers. Unless intelligence.h
// disables or defines the copy operations, copying an Intelligence would
// presumably lead to a double delete. Verify in the header.
36 Intelligence::~Intelligence() {
37 for(int a = 0; a < _nb_actions; a++) delete _q_predictors[a];
38 delete[] _q_predictors;
// Restores the full learner state from a binary stream written by save().
// Reads, in order:
//   - the number of actions and the input size, each checked against
//     _nb_actions and _input->nb_parameters;
//   - the number of stored ticks, checked against _max_memory_tick;
//   - the stored actions, rewards and input snapshots;
//   - each predictor's state, via MappingApproximer::load.
// Each failed check prints a message on cerr. The statements that follow
// each message are on lines not shown in this listing (TODO confirm they
// abort or return).
// NOTE(review): _memory_tick is overwritten before it is validated, and
// only the upper bound is checked in the lines shown, so a negative count
// is not rejected. Stream state (e.g. a truncated file) is never checked,
// so a bad stream can leave the buffers partially filled.
// Values are raw host-layout bytes (size and byte order of int/scalar_t).
// The declaration of na and np is on a line not shown here.
44 void Intelligence::load(istream &is) {
47 is.read((char *) &na, sizeof(int));
48 is.read((char *) &np, sizeof(int));
50 if(na != _nb_actions || np != _input->nb_parameters) {
51 cerr << "Missmatch between the number of actions or input map size and the saved memory." << endl;
55 is.read((char *) &_memory_tick, sizeof(int));
57 if(_memory_tick > _max_memory_tick) {
58 cerr << "Can not load, too large memory dump." << endl;
62 is.read((char *) _actions, sizeof(int) * _memory_tick);
63 is.read((char *) _rewards, sizeof(scalar_t) * _memory_tick);
64 is.read((char *) _memory, sizeof(scalar_t) * _input->nb_parameters * _memory_tick);
// Predictor states follow the memory dump, in action order.
66 for(int a = 0; a < _nb_actions; a++) _q_predictors[a]->load(is);
// Writes the full learner state in the layout load() reads back:
//   - number of actions, input size, number of stored ticks;
//   - the stored actions, rewards and input snapshots;
//   - each predictor's state, via MappingApproximer::save.
// NOTE(review): the first field is written with sizeof(_nb_actions), but
// load() reads it with sizeof(int). The two agree only if _nb_actions is an
// int -- confirm in intelligence.h.
// NOTE(review): stream errors are not checked.
69 void Intelligence::save(ostream &os) {
70 os.write((char *) &_nb_actions, sizeof(_nb_actions));
71 os.write((char *) &_input->nb_parameters, sizeof(int));
72 os.write((char *) &_memory_tick, sizeof(int));
73 os.write((char *) _actions, sizeof(int) * _memory_tick);
74 os.write((char *) _rewards, sizeof(scalar_t) * _memory_tick);
75 os.write((char *) _memory, sizeof(scalar_t) * _input->nb_parameters * _memory_tick);
77 for(int a = 0; a < _nb_actions; a++) _q_predictors[a]->save(os);
// Records the current tick in the replay memory at slot _memory_tick:
//   - the action just taken,
//   - the reward received,
//   - a copy of the current input map parameters, in row _memory_tick
//     of _memory.
// NOTE(review): the increment of _memory_tick is not visible in this
// listing, presumably on an omitted line. Confirm it exists; otherwise
// every call overwrites the same slot.
// ASSERT is a project macro, presumably rejecting out-of-range actions.
80 void Intelligence::update(int last_action, scalar_t last_reward) {
// A full memory terminates the process via abort(), with no diagnostic
// message.
81 if(_memory_tick == _max_memory_tick) abort();
82 ASSERT(last_action >= 0 && last_action < _nb_actions, "Action number out of bounds.");
83 _actions[_memory_tick] = last_action;
84 _rewards[_memory_tick] = last_reward;
// Start of this tick's row in the flat snapshot buffer.
85 int k = _memory_tick * _input->nb_parameters;
86 for(int p = 0; p < _input->nb_parameters; p++) _memory[k++] = _input->parameters[p];
// Dumps only the replay memory to filename:
//   - input size, number of stored ticks;
//   - the stored actions, rewards and input snapshots.
// Unlike save(), it writes neither the number of actions nor the
// predictors. The format matches what load_memory() reads.
// An error message is printed when the file cannot be opened. The failure
// test and the code after the message are on lines not shown here.
// NOTE(review): the ofstream is opened in text mode (no std::ios::binary)
// although raw binary data is written. On platforms that translate
// newlines this can corrupt the dump. Write errors are not checked.
90 void Intelligence::save_memory(char *filename) {
91 ofstream out(filename);
94 cerr << "Can not save to " << filename << "." << endl;
98 out.write((char *) &_input->nb_parameters, sizeof(int));
99 out.write((char *) &_memory_tick, sizeof(int));
100 out.write((char *) _actions, sizeof(int) * _memory_tick);
101 out.write((char *) _rewards, sizeof(scalar_t) * _memory_tick);
102 out.write((char *) _memory, sizeof(scalar_t) * _input->nb_parameters * _memory_tick);
// Loads a replay memory written by save_memory(). It checks that:
//   - the stored input size matches _input->nb_parameters;
//   - the stored tick count fits in _max_memory_tick.
// Each failed check is reported on cerr. It then reads the actions, rewards
// and input snapshots. The predictors are not touched.
// The following are on lines not shown here:
//   - the declaration of np,
//   - the open-failure test,
//   - the statements after each error message.
// NOTE(review): the ifstream is opened in text mode (no std::ios::binary)
// for binary data. _memory_tick is assigned before it is validated, and
// negative counts are not rejected. Read failures are not checked.
106 void Intelligence::load_memory(char *filename) {
107 ifstream in(filename);
110 cerr << "Can not load from " << filename << "." << endl;
115 in.read((char *) &np, sizeof(int));
116 in.read((char *) &_memory_tick, sizeof(int));
118 if(np != _input->nb_parameters) {
119 cerr << "Missmatch between the input map size and the saved memory." << endl;
123 if(_memory_tick > _max_memory_tick) {
124 cerr << "Can not load, too large memory dump." << endl;
128 in.read((char *) _actions, sizeof(int) * _memory_tick);
129 in.read((char *) _rewards, sizeof(scalar_t) * _memory_tick);
130 in.read((char *) _memory, sizeof(scalar_t) * _input->nb_parameters * _memory_tick);
// Trains the per-action Q predictors on the replay memory.
//
// Data split:
//   - the first nb_train_ticks = int(_memory_tick * proportion_for_training)
//     ticks form the training set;
//   - the remaining ticks are only used to report test error.
//
// Sample weights: for action a, sample t has weight 1 if action a was taken
// at tick t and t lies in the training part, and 0 otherwise. These weights
// are passed to the predictor through set_learning_input. The remaining
// arguments of that call are on lines not shown in this listing.
//
// Training rounds: each of the _nb_weak_learners rounds
//   - recomputes every target as lambda * s + reward, where s is derived
//     from the predictors' outputs at tick t + 1 (the lines computing s are
//     not shown; presumably a max over actions);
//   - lets every predictor perform one learn_one_step.
// With lambda fixed to 0.0, the targets are just the immediate rewards.
//
// Error report: per-action sums of sq(prediction - target) over the
// training ticks and the test ticks are printed, prefixed with
// "ERROR_TRAIN" / "ERROR_TEST" and u + 1 (the current round number).
//
// NOTE(review): target has _memory_tick entries, but the lines shown only
// assign indices 0 .. _memory_tick - 2. target[_memory_tick - 1] is then
// read uninitialized by learn_one_step() and by the test-error loop, which
// runs up to t = _memory_tick - 1.
// NOTE(review): target, e_train and e_test are variable-length arrays.
// These are a GCC/Clang extension, not standard C++. std::vector<scalar_t>
// would be portable and would avoid a stack allocation proportional to
// _memory_tick.
// NOTE(review): the innermost loop assigns to `u`. Unless a scalar_t u is
// declared on an omitted line inside the t loop, this overwrites the int
// round counter of the outer loop -- confirm.
133 void Intelligence::learn(scalar_t proportion_for_training) {
134 scalar_t **sample_weigths;
// Leading ticks used for training; the rest form the test set.
135 int nb_train_ticks = int(_memory_tick * proportion_for_training);
136 sample_weigths = new scalar_t *[_nb_actions];
137 for(int a = 0; a < _nb_actions; a++) {
138 sample_weigths[a] = new scalar_t[_memory_tick];
139 for(int t = 0; t < _memory_tick; t++)
140 if(_actions[t] == a && t < nb_train_ticks) sample_weigths[a][t] = 1.0;
141 else sample_weigths[a][t] = 0.0;
142 _q_predictors[a]->set_learning_input(_input->nb_parameters,
// Initial targets: immediate rewards, set for every tick except the last.
148 scalar_t target[_memory_tick];
149 for(int t = 0; t < _memory_tick - 1; t++) target[t] = _rewards[t];
151 for(int u = 0; u < _nb_weak_learners; u++) {
153 for(int t = 0; t < _memory_tick - 1; t++) {
155 for(int a = 0; a < _nb_actions; a++) {
156 u = _q_predictors[a]->_outputs_on_samples[t+1];
// Weight of the bootstrapped next-tick term; 0.0 disables it.
159 const scalar_t lambda = 0.0;
160 target[t] = lambda * s + _rewards[t];
163 for(int a = 0; a < _nb_actions; a++) _q_predictors[a]->learn_one_step(target);
// Per-action squared error on the training ticks. Each tick counts only for
// the action actually taken there.
166 scalar_t e_train[_nb_actions];
167 for(int a = 0; a < _nb_actions; a++) e_train[a] = 0;
168 for(int t = 0; t < nb_train_ticks; t++)
169 e_train[_actions[t]] += sq(_q_predictors[_actions[t]]->_outputs_on_samples[t] - target[t]);
170 cout << "ERROR_TRAIN " << u+1;
171 for(int a = 0; a < _nb_actions; a++) cout << " " << e_train[a];
// Same measure on the held-out ticks.
176 scalar_t e_test[_nb_actions];
177 for(int a = 0; a < _nb_actions; a++) e_test[a] = 0;
178 for(int t = nb_train_ticks; t < _memory_tick; t++)
179 e_test[_actions[t]] += sq(_q_predictors[_actions[t]]->_outputs_on_samples[t] - target[t]);
180 cout << "ERROR_TEST " << u+1;
181 for(int a = 0; a < _nb_actions; a++) cout << " " << e_test[a];
// Release the per-action weight arrays allocated above.
186 for(int a = 0; a < _nb_actions; a++) delete[] sample_weigths[a];
187 delete[] sample_weigths;
190 int Intelligence::best_action() {
193 cout << "ACTION_SCORES";
194 for(int a = 0; a < _nb_actions; a++) {
195 q = _q_predictors[a]->predict(_input->parameters);
197 if(a == 0 || q > max_q) {