64d5cede277bc7b1261a2e0da874616471460cf0
[universe.git] / map.cc
1
2 ////////////////////////////////////////////////////////////////////////////////
3 // This program is free software; you can redistribute it and/or              //
4 // modify it under the terms of the GNU General Public License                //
5 // version 2 as published by the Free Software Foundation.                    //
6 //                                                                            //
7 // This program is distributed in the hope that it will be useful, but        //
8 // WITHOUT ANY WARRANTY; without even the implied warranty of                 //
9 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU          //
10 // General Public License for more details.                                   //
11 //                                                                            //
12 // Written and (C) by François Fleuret                                        //
13 // Contact <francois.fleuret@epfl.ch> for comments & bug reports              //
14 ////////////////////////////////////////////////////////////////////////////////
15
16 // $Id: map.cc,v 1.4 2006-07-26 06:53:18 fleuret Exp $
17
18 #include "map.h"
19
20 Map::Map() { parameters = 0; }
21
22 Map::~Map() { delete[] parameters; }
23
24 void Map::init(int np) {
25   nb_parameters = np;
26   parameters = new scalar_t[nb_parameters];
27   for(int n = 0; n < nb_parameters; n++) parameters[n] = 0;
28 }
29
30 MapConcatener::MapConcatener(int nb_max_maps) : _nb_max_maps(nb_max_maps), _nb_maps(0),
31                                                 _maps(new Map *[_nb_max_maps]) { }
32
33 MapConcatener::~MapConcatener() {
34   delete[] _maps;
35 }
36
37 void MapConcatener::add_map(Map *map) {
38   if(_nb_maps >= _nb_max_maps) abort();
39   _maps[_nb_maps++] = map;
40 }
41
42 void MapConcatener::init() {
43   int s = 0;
44   for(int m = 0; m < _nb_maps; m++) s += _maps[m]->nb_parameters;
45   Map::init(s);
46 }
47
48 void MapConcatener::update_map() {
49   for(int m = 0; m < _nb_maps; m++) _maps[m]->update_map();
50   int k = 0;
51   for(int m = 0; m < _nb_maps; m++) for(int l = 0; l < _maps[m]->nb_parameters; l++)
52     parameters[k++] = _maps[m]->parameters[l];
53 }
54
55 //////////////////////////////////////////////////////////////////////
56
57 MapExpander::MapExpander(int nb_units) : _nb_units(nb_units),
58                                          _unit_weights(0),
59                                          _input(0) { }
60
61 MapExpander::~MapExpander() {
62   delete[] _unit_weights;
63   delete[] _state_switch;
64 }
65
66 void MapExpander::load(istream &is) {
67   is.read((char *) &_nb_units, sizeof(_nb_units));
68   int input_size;
69   is.read((char *) &input_size, sizeof(input_size));
70
71   if(input_size != _input->nb_parameters) {
72     cerr << "Loaded map expander size missmatch." << endl;
73     exit(1);
74   }
75
76   _unit_weights = new scalar_t[_nb_units * 2 * (_input->nb_parameters + 1)];
77   parameters = new scalar_t[_nb_units];
78   _state_switch = new scalar_t[_nb_units];
79
80   is.read((char *) _unit_weights, sizeof(scalar_t) * _nb_units * 2 * (_input->nb_parameters + 1));
81
82   for(int u = 0; u < _nb_units; u++) {
83     _state_switch[u] = 0.0;
84     parameters[u] = 0.0;
85   }
86 }
87
88 void MapExpander::save(ostream &os) {
89   os.write((char *) &_nb_units, sizeof(_nb_units));
90   os.write((char *) &_input->nb_parameters, sizeof(_input->nb_parameters));
91   os.write((char *) _unit_weights, sizeof(scalar_t) * _nb_units * 2 * (_input->nb_parameters + 1));
92 }
93
94 void MapExpander::set_input(Map *input) {
95   ASSERT(!_input, "You can not set the input of an expanding map twice.");
96   _input = input;
97 }
98
99 void MapExpander::init() {
100   ASSERT(!_unit_weights, "You can not initialize a MapExpander twice.");
101   _unit_weights = new scalar_t[_nb_units * 2 * (_input->nb_parameters + 1)];
102   _state_switch = new scalar_t[_nb_units];
103   for(int k = 0; k < _nb_units * 2 * (_input->nb_parameters + 1); k++)
104     _unit_weights[k] = 2 * (drand48() - 0.5);
105   Map::init(_nb_units);
106   update_map();
107 }
108
109 void MapExpander::update_map() {
110   ASSERT(_unit_weights, "You have to call MapExpander::init() before using it.");
111   _input->update_map();
112   scalar_t s1, s2;
113   int k = 0;
114   for(int u = 0; u < _nb_units; u++) {
115     s1 = _unit_weights[k++];
116     for(int p = 0; p < _input->nb_parameters; p++)
117       s1 += _unit_weights[k++] * _input->parameters[p];
118     s2 = _unit_weights[k++];
119     for(int p = 0; p < _input->nb_parameters; p++)
120       s2 += _unit_weights[k++] * _input->parameters[p];
121     parameters[u] = s1;
122     _state_switch[u] = s2;
123   }
124 }
125
126 //////////////////////////////////////////////////////////////////////