2c269ba0eb9dcc35dcb15ef7ca6df2bc9e0fddc3
[flatland.git] / flatland-test.py
1 #!/usr/bin/env python
2
3 #
4 #  flatland is a simple 2d physical simulator
5 #
6 #  Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/
7 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
8 #
9 #  This file is part of flatland
10 #
11 #  flatland is free software: you can redistribute it and/or modify it
12 #  under the terms of the GNU General Public License version 3 as
13 #  published by the Free Software Foundation.
14 #
15 #  flatland is distributed in the hope that it will be useful, but
16 #  WITHOUT ANY WARRANTY; without even the implied warranty of
17 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18 #  General Public License for more details.
19 #
20 #  You should have received a copy of the GNU General Public License
21 #  along with flatland.  If not, see <http://www.gnu.org/licenses/>.
22 #
23
24 import torch
25 import torchvision
26 import argparse
27
28 import flatland
29
30 ######################################################################
31
32 parser = argparse.ArgumentParser(
33     description = 'Dummy test of the flatland sequence generation.',
34     formatter_class = argparse.ArgumentDefaultsHelpFormatter
35 )
36
37 parser.add_argument('--seed',
38                     type = int, default = 0,
39                     help = 'Random seed, < 0 is no seeding')
40
41 parser.add_argument('--width',
42                     type = int, default = 80,
43                     help = 'Image width')
44
45 parser.add_argument('--height',
46                     type = int, default = 80,
47                     help = 'Image height')
48
49 parser.add_argument('--nb_shapes',
50                     type = int, default = 8,
51                     help = 'Image height')
52
53 parser.add_argument('--nb_sequences',
54                     type = int, default = 8,
55                     help = 'How many sequences to generate')
56
57 parser.add_argument('--nb_images_per_sequences',
58                     type = int, default = 16,
59                     help = 'How many images per sequence')
60
61 parser.add_argument('--randomize_colors',
62                     action='store_true', default=True,
63                     help = 'Should the shapes be of different colors')
64
65 parser.add_argument('--randomize_shape_size',
66                     action='store_true', default=False,
67                     help = 'Should the shapes be of different size')
68
69 args = parser.parse_args()
70
71 if args.seed >= 0:
72     torch.manual_seed(args.seed)
73
74 ######################################################################
75
76 def sequences_to_image(x, gap = 1, gap_color = (0, 128, 255)):
77     from PIL import Image
78
79     nb_sequences = x.size(0)
80     nb_images_per_sequences = x.size(1)
81     nb_channels = 3
82
83     if x.size(2) != nb_channels:
84         print('Can only handle 3 channel tensors.')
85         exit(1)
86
87     height = x.size(3)
88     width = x.size(4)
89
90     result = torch.ByteTensor(nb_channels,
91                               gap + nb_sequences * (height + gap),
92                               gap + nb_images_per_sequences * (width + gap))
93
94     result.copy_(torch.Tensor(gap_color).view(-1, 1, 1).expand_as(result))
95
96     for s in range(0, nb_sequences):
97         for i in range(0, nb_images_per_sequences):
98             result.narrow(1, gap + s * (height + gap), height) \
99                   .narrow(2, gap + i * (width + gap), width) \
100                   .copy_(x[s][i])
101
102     result_numpy = result.cpu().byte().transpose(0, 2).transpose(0, 1).numpy()
103
104     return Image.fromarray(result_numpy, 'RGB')
105
106 ######################################################################
107
108 x = flatland.generate_sequence(False,
109                                args.nb_sequences,
110                                args.nb_images_per_sequences,
111                                args.height, args.width,
112                                args.nb_shapes,
113                                args.randomize_shape_size,
114                                args.randomize_colors)
115
116 sequences_to_image(x, gap = 3, gap_color = (0, 150, 200)).save('sequences.png')
117
118 print('Saved sequences.png.')