X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=flatland.git;a=blobdiff_plain;f=flatland-test.py;fp=flatland-test.py;h=905e842d0036e497eb074c1ceb2fc5a953baa848;hp=0000000000000000000000000000000000000000;hb=c1f1040936d977cd2b3a276c725e223198377d2a;hpb=90d1c5704c30e7f1d041e32eacbc2893741110e1 diff --git a/flatland-test.py b/flatland-test.py new file mode 100755 index 0000000..905e842 --- /dev/null +++ b/flatland-test.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python-for-pytorch + +# +# flatland is a simple 2d physical simulator +# +# Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/ +# Written by Francois Fleuret +# +# This file is part of flatland +# +# flatland is free software: you can redistribute it and/or modify it +# under the terms of the GNU General Public License version 3 as +# published by the Free Software Foundation. +# +# flatland is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with flatland. If not, see . +# + +import torch +import torchvision +import argparse + +from _ext import flatland + +###################################################################### + +parser = argparse.ArgumentParser( + description = 'Dummy test of the flatland sequence generation.', + formatter_class = argparse.ArgumentDefaultsHelpFormatter +) + +parser.add_argument('--seed', + type = int, default = 0, + help = 'Random seed, < 0 is no seeding') + +parser.add_argument('--width', + type = int, default = 80, + help = 'Image width') + +parser.add_argument('--height', + type = int, default = 80, + help = 'Image height') + +parser.add_argument('--nb_shapes', + type = int, default = 10, + help = 'Image height') + +parser.add_argument('--nb_sequences', + type = int, default = 1, + help = 'How many sequences to generate') + +parser.add_argument('--nb_images_per_sequences', + type = int, default = 3, + help = 'How many images per sequence') + +parser.add_argument('--randomize_colors', + action='store_true', default=False, + help = 'Should the shapes be of different colors') + +parser.add_argument('--randomize_shape_size', + action='store_true', default=False, + help = 'Should the shapes be of different size') + +args = parser.parse_args() + +if args.seed >= 0: + torch.manual_seed(args.seed) + +###################################################################### + +def sequences_to_image(x, gap = 1, gap_color = (0, 128, 255)): + from PIL import Image + + nb_sequences = x.size(0) + nb_images_per_sequences = x.size(1) + nb_channels = 3 + + if x.size(2) != nb_channels: + print('Can only handle 3 channel tensors.') + exit(1) + + height = x.size(3) + width = x.size(4) + + result = torch.ByteTensor(nb_channels, + gap + nb_sequences * (height + gap), + gap + nb_images_per_sequences * (width + gap)) + + result.copy_(torch.Tensor(gap_color).view(-1, 1, 1).expand_as(result)) + + for s in range(0, nb_sequences): + for i in range(0, nb_images_per_sequences): + result.narrow(1, gap + s * (height + gap), height) \ + .narrow(2, gap + i * (width + gap), width) \ + .copy_(x[s][i]) + + result_numpy = result.cpu().byte().transpose(0, 2).transpose(0, 1).numpy() + + return Image.fromarray(result_numpy, 'RGB') + +###################################################################### + +x = flatland.generate_sequence(False, + args.nb_sequences, + args.nb_images_per_sequences, + args.height, args.width, + args.nb_shapes, + args.randomize_colors, + args.randomize_shape_size) + +sequences_to_image(x, gap = 1, gap_color = (0, 0, 0)).save('sequences.png')