183218b5ef8ff2b56038e4b16878cff2b7c1a520
[universe.git] / map.cc
1
2 ////////////////////////////////////////////////////////////////////////////////
3 // This program is free software; you can redistribute it and/or              //
4 // modify it under the terms of the GNU General Public License                //
5 // version 2 as published by the Free Software Foundation.                    //
6 //                                                                            //
7 // This program is distributed in the hope that it will be useful, but        //
8 // WITHOUT ANY WARRANTY; without even the implied warranty of                 //
9 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU          //
10 // General Public License for more details.                                   //
11 //                                                                            //
12 // Written and (C) by François Fleuret                                        //
13 // Contact <francois.fleuret@epfl.ch> for comments & bug reports              //
14 ////////////////////////////////////////////////////////////////////////////////
15
16 #include "map.h"
17
18 Map::Map() { parameters = 0; }
19
20 Map::~Map() { delete[] parameters; }
21
22 void Map::init(int np) {
23   nb_parameters = np;
24   parameters = new scalar_t[nb_parameters];
25   for(int n = 0; n < nb_parameters; n++) parameters[n] = 0;
26 }
27
28 MapConcatener::MapConcatener(int nb_max_maps) : _nb_max_maps(nb_max_maps), _nb_maps(0),
29                                                 _maps(new Map *[_nb_max_maps]) { }
30
31 MapConcatener::~MapConcatener() {
32   delete[] _maps;
33 }
34
35 void MapConcatener::add_map(Map *map) {
36   if(_nb_maps >= _nb_max_maps) abort();
37   _maps[_nb_maps++] = map;
38 }
39
40 void MapConcatener::init() {
41   int s = 0;
42   for(int m = 0; m < _nb_maps; m++) s += _maps[m]->nb_parameters;
43   Map::init(s);
44 }
45
46 void MapConcatener::update_map() {
47   for(int m = 0; m < _nb_maps; m++) _maps[m]->update_map();
48   int k = 0;
49   for(int m = 0; m < _nb_maps; m++) for(int l = 0; l < _maps[m]->nb_parameters; l++)
50     parameters[k++] = _maps[m]->parameters[l];
51 }
52
53 //////////////////////////////////////////////////////////////////////
54
55 MapExpander::MapExpander(int nb_units) : _nb_units(nb_units),
56                                          _unit_weights(0),
57                                          _input(0) { }
58
59 MapExpander::~MapExpander() {
60   delete[] _unit_weights;
61   delete[] _state_switch;
62 }
63
64 void MapExpander::load(istream &is) {
65   is.read((char *) &_nb_units, sizeof(_nb_units));
66   int input_size;
67   is.read((char *) &input_size, sizeof(input_size));
68
69   if(input_size != _input->nb_parameters) {
70     cerr << "Loaded map expander size missmatch." << endl;
71     exit(1);
72   }
73
74   _unit_weights = new scalar_t[_nb_units * 2 * (_input->nb_parameters + 1)];
75   parameters = new scalar_t[_nb_units];
76   _state_switch = new scalar_t[_nb_units];
77
78   is.read((char *) _unit_weights, sizeof(scalar_t) * _nb_units * 2 * (_input->nb_parameters + 1));
79
80   for(int u = 0; u < _nb_units; u++) {
81     _state_switch[u] = 0.0;
82     parameters[u] = 0.0;
83   }
84 }
85
86 void MapExpander::save(ostream &os) {
87   os.write((char *) &_nb_units, sizeof(_nb_units));
88   os.write((char *) &_input->nb_parameters, sizeof(_input->nb_parameters));
89   os.write((char *) _unit_weights, sizeof(scalar_t) * _nb_units * 2 * (_input->nb_parameters + 1));
90 }
91
92 void MapExpander::set_input(Map *input) {
93   ASSERT(!_input, "You can not set the input of an expanding map twice.");
94   _input = input;
95 }
96
97 void MapExpander::init() {
98   ASSERT(!_unit_weights, "You can not initialize a MapExpander twice.");
99   _unit_weights = new scalar_t[_nb_units * 2 * (_input->nb_parameters + 1)];
100   _state_switch = new scalar_t[_nb_units];
101   for(int k = 0; k < _nb_units * 2 * (_input->nb_parameters + 1); k++)
102     _unit_weights[k] = 2 * (drand48() - 0.5);
103   Map::init(_nb_units);
104   update_map();
105 }
106
107 void MapExpander::update_map() {
108   ASSERT(_unit_weights, "You have to call MapExpander::init() before using it.");
109   _input->update_map();
110   scalar_t s1, s2;
111   int k = 0;
112   for(int u = 0; u < _nb_units; u++) {
113     s1 = _unit_weights[k++];
114     for(int p = 0; p < _input->nb_parameters; p++)
115       s1 += _unit_weights[k++] * _input->parameters[p];
116     s2 = _unit_weights[k++];
117     for(int p = 0; p < _input->nb_parameters; p++)
118       s2 += _unit_weights[k++] * _input->parameters[p];
119     parameters[u] = s1;
120     _state_switch[u] = s2;
121   }
122 }
123
124 //////////////////////////////////////////////////////////////////////