Added headers.
[flatland.git] / flatland-test.py
diff --git a/flatland-test.py b/flatland-test.py
new file mode 100755 (executable)
index 0000000..905e842
--- /dev/null
@@ -0,0 +1,116 @@
+#!/usr/bin/env python-for-pytorch
+
+#
+#  flatland is a simple 2d physical simulator
+#
+#  Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/
+#  Written by Francois Fleuret <francois.fleuret@idiap.ch>
+#
+#  This file is part of flatland
+#
+#  flatland is free software: you can redistribute it and/or modify it
+#  under the terms of the GNU General Public License version 3 as
+#  published by the Free Software Foundation.
+#
+#  flatland is distributed in the hope that it will be useful, but
+#  WITHOUT ANY WARRANTY; without even the implied warranty of
+#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+#  General Public License for more details.
+#
+#  You should have received a copy of the GNU General Public License
+#  along with flatland.  If not, see <http://www.gnu.org/licenses/>.
+#
+
+import torch
+import torchvision
+import argparse
+
+from _ext import flatland
+
+######################################################################
+
+parser = argparse.ArgumentParser(
+    description = 'Dummy test of the flatland sequence generation.',
+    formatter_class = argparse.ArgumentDefaultsHelpFormatter
+)
+
+parser.add_argument('--seed',
+                    type = int, default = 0,
+                    help = 'Random seed, < 0 is no seeding')
+
+parser.add_argument('--width',
+                    type = int, default = 80,
+                    help = 'Image width')
+
+parser.add_argument('--height',
+                    type = int, default = 80,
+                    help = 'Image height')
+
+parser.add_argument('--nb_shapes',
+                    type = int, default = 10,
+                    help = 'Image height')
+
+parser.add_argument('--nb_sequences',
+                    type = int, default = 1,
+                    help = 'How many sequences to generate')
+
+parser.add_argument('--nb_images_per_sequences',
+                    type = int, default = 3,
+                    help = 'How many images per sequence')
+
+parser.add_argument('--randomize_colors',
+                    action='store_true', default=False,
+                    help = 'Should the shapes be of different colors')
+
+parser.add_argument('--randomize_shape_size',
+                    action='store_true', default=False,
+                    help = 'Should the shapes be of different size')
+
+args = parser.parse_args()
+
+if args.seed >= 0:
+    torch.manual_seed(args.seed)
+
+######################################################################
+
+def sequences_to_image(x, gap = 1, gap_color = (0, 128, 255)):
+    from PIL import Image
+
+    nb_sequences = x.size(0)
+    nb_images_per_sequences = x.size(1)
+    nb_channels = 3
+
+    if x.size(2) != nb_channels:
+        print('Can only handle 3 channel tensors.')
+        exit(1)
+
+    height = x.size(3)
+    width = x.size(4)
+
+    result = torch.ByteTensor(nb_channels,
+                              gap + nb_sequences * (height + gap),
+                              gap + nb_images_per_sequences * (width + gap))
+
+    result.copy_(torch.Tensor(gap_color).view(-1, 1, 1).expand_as(result))
+
+    for s in range(0, nb_sequences):
+        for i in range(0, nb_images_per_sequences):
+            result.narrow(1, gap + s * (height + gap), height) \
+                  .narrow(2, gap + i * (width + gap), width) \
+                  .copy_(x[s][i])
+
+    result_numpy = result.cpu().byte().transpose(0, 2).transpose(0, 1).numpy()
+
+    return Image.fromarray(result_numpy, 'RGB')
+
+######################################################################
+
+x = flatland.generate_sequence(False,
+                               args.nb_sequences,
+                               args.nb_images_per_sequences,
+                               args.height, args.width,
+                               args.nb_shapes,
+                               args.randomize_colors,
+                               args.randomize_shape_size)
+
+sequences_to_image(x, gap = 1, gap_color = (0, 0, 0)).save('sequences.png')