Added #include <unistd.h> for nice().
[mlp.git] / images.cc
1 /*
2  *  mlp-mnist is an implementation of a multi-layer neural network.
3  *
4  *  Copyright (c) 2006 École Polytechnique Fédérale de Lausanne,
5  *  http://www.epfl.ch
6  *
7  *  Written by Francois Fleuret <francois@fleuret.org>
8  *
9  *  This file is part of mlp-mnist.
10  *
11  *  mlp-mnist is free software: you can redistribute it and/or modify
12  *  it under the terms of the GNU General Public License version 3 as
13  *  published by the Free Software Foundation.
14  *
15  *  mlp-mnist is distributed in the hope that it will be useful, but
16  *  WITHOUT ANY WARRANTY; without even the implied warranty of
17  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  *  General Public License for more details.
19  *
20  *  You should have received a copy of the GNU General Public License
21  *  along with mlp-mnist.  If not, see <http://www.gnu.org/licenses/>.
22  *
23  */
24
25 #include "images.h"
26 #include <stdlib.h>
27
28 PixelMaps::PixelMaps(int size) : _nb_ref(0),_core(new unsigned char[size]) {}
29 PixelMaps::~PixelMaps() { delete[] _core; }
30 PixelMaps *PixelMaps::add_ref() { _nb_ref++; return this; }
31 void PixelMaps::del_ref() { _nb_ref--; if(_nb_ref == 0) delete this; }
32
33 const unsigned int mnist_pictures_magic = 0x00000803;
34 const unsigned int mnist_labels_magic = 0x00000801;
35
36 inline unsigned int read_high_endian_int(istream &is) {
37   unsigned int result;
38   char *s = (char *) &result;
39   char c;
40   is.read(s, sizeof(result));
41   c = s[0]; s[0] = s[3]; s[3] = c;
42   c = s[1]; s[1] = s[2]; s[2] = c;
43   return result;
44 }
45
46 ImageSet::ImageSet() : _nb_pics(-1), _nb_obj(0), _width(-1), _height(-1),
47                        _pixel_maps(0), _pixels(0), _labels(0), _used_picture(0) { }
48
49 ImageSet::~ImageSet() {
50   if(_pixel_maps) _pixel_maps->del_ref();
51   delete[] _pixels;
52   delete[] _labels;
53   delete[] _used_picture;
54 }
55
56 void ImageSet::reset_used_pictures() {
57   for(int p = 0; p < _nb_pics; p++) _used_picture[p] = false;
58 }
59
60 int ImageSet::nb_unused_pictures() {
61   int n = 0;
62   for(int p = 0; p < _nb_pics; p++) if(!_used_picture[p]) n++;
63   return n;
64 }
65
66 int ImageSet::pick_unused_picture() {
67   int m;
68   do { m = int(drand48() * _nb_pics); } while(_used_picture[m]);
69   _used_picture[m] = true;
70   return m;
71 }
72
73 void ImageSet::load_mnist_format(char *picture_file_name, char *label_file_name) {
74   unsigned int magic;
75
76   ifstream picture_is(picture_file_name);
77
78   if(picture_is.fail()) {
79     cerr << "Can not open file [" << picture_file_name << "].\n";
80     exit(1);
81   }
82
83   magic = read_high_endian_int(picture_is);
84   if(magic != mnist_pictures_magic) {
85     cerr << "Invalid magic for picture, file [" << picture_file_name << "] number [" << magic << "]\n";
86     exit(1);
87   }
88
89   _nb_pics = read_high_endian_int(picture_is);
90   _width = read_high_endian_int(picture_is);
91   _height = read_high_endian_int(picture_is);
92
93   ifstream label_is(label_file_name);
94   if(label_is.fail()) {
95     cerr << "Can not open file [" << label_file_name << "].\n";
96     exit(1);
97   }
98
99   magic = read_high_endian_int(label_is);
100   if(magic != mnist_labels_magic) {
101     cerr << "Invalid magic for labels, file [" << label_file_name << "] number [" << magic << "]\n";
102     exit(1);
103   }
104
105   int nb_pics_labels = read_high_endian_int(label_is);
106
107   if(nb_pics_labels != _nb_pics) {
108     cerr << "Inconsistency between the number of pictures in [" << picture_file_name << "] (" << _nb_pics << ")"
109          << " and the number of labels in [" << label_file_name << "] (" << nb_pics_labels << ").\n";
110     exit(1);
111   }
112
113   PixelMaps *pm = new PixelMaps(_nb_pics * _width * _height);
114   _pixel_maps = pm->add_ref();
115   _pixels = new unsigned char *[_nb_pics];
116   _labels = new unsigned char[_nb_pics];
117   _used_picture = new bool[_nb_pics];
118
119   picture_is.read((char *) _pixel_maps->_core, _nb_pics * _width * _height);
120   label_is.read((char *) _labels, _nb_pics);
121
122   for(int i = 0; i < _nb_pics * _width * _height; i++) _pixel_maps->_core[i] = 255 - _pixel_maps->_core[i];
123
124   _nb_obj = 0;
125   for(int n = 0; n < _nb_pics; n++) {
126     _pixels[n] = _pixel_maps->_core + n * _width * _height;
127     if(_labels[n] > _nb_obj) _nb_obj = _labels[n];
128   }
129   _nb_obj++;
130
131   reset_used_pictures();
132 }
133
134 void ImageSet::sample_among_unused_pictures(ImageSet &is, int nb) {
135   if(nb > is.nb_unused_pictures()) {
136     cerr << "Trying to extract " << nb << " pictures from a set of " << is.nb_unused_pictures() << "\n";
137     exit(1);
138   }
139
140   _nb_pics = nb;
141   _width = is._width;
142   _height = is._height;
143   _nb_obj = is._nb_obj;
144   _pixel_maps = is._pixel_maps->add_ref();
145   _pixels = new unsigned char *[_nb_pics];
146   _labels = new unsigned char[_nb_pics];
147   _used_picture = new bool[_nb_pics];
148   for(int n = 0; n < _nb_pics; n++) {
149     int m = is.pick_unused_picture();
150     _pixels[n] = is._pixels[m];
151     _labels[n] = is._labels[m];
152   }
153
154   reset_used_pictures();
155 }