automatic commit
[mlp.git] / images.cc
diff --git a/images.cc b/images.cc
new file mode 100644 (file)
index 0000000..e1b07b4
--- /dev/null
+++ b/images.cc
@@ -0,0 +1,153 @@
+/*
+ *  mlp-mnist is an implementation of a multi-layer neural network.
+ *
+ *  Copyright (c) 2008 Idiap Research Institute, http://www.idiap.ch/
+ *  Written by Francois Fleuret <francois.fleuret@idiap.ch>
+ *
+ *  This file is part of mlp-mnist.
+ *
+ *  mlp-mnist is free software: you can redistribute it and/or modify
+ *  it under the terms of the GNU General Public License version 3 as
+ *  published by the Free Software Foundation.
+ *
+ *  mlp-mnist is distributed in the hope that it will be useful, but
+ *  WITHOUT ANY WARRANTY; without even the implied warranty of
+ *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ *  General Public License for more details.
+ *
+ *  You should have received a copy of the GNU General Public License
+ *  along with mlp-mnist.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+#include "images.h"
+#include <stdlib.h>
+
+PixelMaps::PixelMaps(int size) : _nb_ref(0),_core(new unsigned char[size]) {}
+PixelMaps::~PixelMaps() { delete[] _core; }
+PixelMaps *PixelMaps::add_ref() { _nb_ref++; return this; }
+void PixelMaps::del_ref() { _nb_ref--; if(_nb_ref == 0) delete this; }
+
+const unsigned int mnist_pictures_magic = 0x00000803;
+const unsigned int mnist_labels_magic = 0x00000801;
+
+inline unsigned int read_high_endian_int(istream &is) {
+  unsigned int result;
+  char *s = (char *) &result;
+  char c;
+  is.read(s, sizeof(result));
+  c = s[0]; s[0] = s[3]; s[3] = c;
+  c = s[1]; s[1] = s[2]; s[2] = c;
+  return result;
+}
+
+ImageSet::ImageSet() : _nb_pics(-1), _nb_obj(0), _width(-1), _height(-1),
+                       _pixel_maps(0), _pixels(0), _labels(0), _used_picture(0) { }
+
+ImageSet::~ImageSet() {
+  if(_pixel_maps) _pixel_maps->del_ref();
+  delete[] _pixels;
+  delete[] _labels;
+  delete[] _used_picture;
+}
+
+void ImageSet::reset_used_pictures() {
+  for(int p = 0; p < _nb_pics; p++) _used_picture[p] = false;
+}
+
+int ImageSet::nb_unused_pictures() {
+  int n = 0;
+  for(int p = 0; p < _nb_pics; p++) if(!_used_picture[p]) n++;
+  return n;
+}
+
+int ImageSet::pick_unused_picture() {
+  int m;
+  do { m = int(drand48() * _nb_pics); } while(_used_picture[m]);
+  _used_picture[m] = true;
+  return m;
+}
+
+void ImageSet::extract_unused_pictures(ImageSet &is, int nb) {
+  if(nb > is.nb_unused_pictures()) {
+    cerr << "Trying to extract " << nb << " pictures from a set of " << is.nb_unused_pictures() << "\n";
+    exit(1);
+  }
+
+  _nb_pics = nb;
+  _width = is._width;
+  _height = is._height;
+  _nb_obj = is._nb_obj;
+  _pixel_maps = is._pixel_maps->add_ref();
+  _pixels = new unsigned char *[_nb_pics];
+  _labels = new unsigned char[_nb_pics];
+  _used_picture = new bool[_nb_pics];
+  for(int n = 0; n < _nb_pics; n++) {
+    int m = is.pick_unused_picture();
+    _pixels[n] = is._pixels[m];
+    _labels[n] = is._labels[m];
+  }
+
+  reset_used_pictures();
+}
+
+void ImageSet::load_mnist_format(char *picture_file_name, char *label_file_name) {
+  unsigned int magic;
+
+  ifstream picture_is(picture_file_name);
+
+  if(picture_is.fail()) {
+    cerr << "Can not open file [" << picture_file_name << "].\n";
+    exit(1);
+  }
+
+  magic = read_high_endian_int(picture_is);
+  if(magic != mnist_pictures_magic) {
+    cerr << "Invalid magic for picture, file [" << picture_file_name << "] number [" << magic << "]\n";
+    exit(1);
+  }
+
+  _nb_pics = read_high_endian_int(picture_is);
+  _width = read_high_endian_int(picture_is);
+  _height = read_high_endian_int(picture_is);
+
+  ifstream label_is(label_file_name);
+  if(label_is.fail()) {
+    cerr << "Can not open file [" << label_file_name << "].\n";
+    exit(1);
+  }
+
+  magic = read_high_endian_int(label_is);
+  if(magic != mnist_labels_magic) {
+    cerr << "Invalid magic for labels, file [" << label_file_name << "] number [" << magic << "]\n";
+    exit(1);
+  }
+
+  int nb_pics_labels = read_high_endian_int(label_is);
+
+  if(nb_pics_labels != _nb_pics) {
+    cerr << "Inconsistency between the number of pictures in [" << picture_file_name << "] (" << _nb_pics << ")"
+         << " and the number of labels in [" << label_file_name << "] (" << nb_pics_labels << ").\n";
+    exit(1);
+  }
+
+  PixelMaps *pm = new PixelMaps(_nb_pics * _width * _height);
+  _pixel_maps = pm->add_ref();
+  _pixels = new unsigned char *[_nb_pics];
+  _labels = new unsigned char[_nb_pics];
+  _used_picture = new bool[_nb_pics];
+
+  picture_is.read((char *) _pixel_maps->_core, _nb_pics * _width * _height);
+  label_is.read((char *) _labels, _nb_pics);
+
+  for(int i = 0; i < _nb_pics * _width * _height; i++) _pixel_maps->_core[i] = 255 - _pixel_maps->_core[i];
+
+  _nb_obj = 0;
+  for(int n = 0; n < _nb_pics; n++) {
+    _pixels[n] = _pixel_maps->_core + n * _width * _height;
+    if(_labels[n] > _nb_obj) _nb_obj = _labels[n];
+  }
+  _nb_obj++;
+
+  reset_used_pictures();
+}