ac3eb7e2506e809d640658094587e1df3a61db95
[flatland.git] / test.py
1 #!/usr/bin/env python-for-pytorch
2
3 import torch
4 import torchvision
5 import argparse
6
7 from _ext import flatland
8
9 ######################################################################
10
parser = argparse.ArgumentParser(
    description='Dummy test of the flatland sequence generation.',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--seed',
                    type=int, default=0,
                    help='Random seed, < 0 is no seeding')

parser.add_argument('--width',
                    type=int, default=80,
                    help='Image width')

parser.add_argument('--height',
                    type=int, default=80,
                    help='Image height')

parser.add_argument('--nb_shapes',
                    type=int, default=10,
                    # Fixed: help text was a copy-paste of --height's 'Image height'
                    help='Number of shapes per image')

parser.add_argument('--nb_sequences',
                    type=int, default=1,
                    help='How many sequences to generate')

parser.add_argument('--nb_images_per_sequences',
                    type=int, default=3,
                    help='How many images per sequence')

parser.add_argument('--randomize_colors',
                    action='store_true', default=False,
                    help='Should the shapes be of different colors')

parser.add_argument('--randomize_shape_size',
                    action='store_true', default=False,
                    help='Should the shapes be of different size')

args = parser.parse_args()

# Seed torch's RNG for reproducible sequences; a negative seed disables seeding.
if args.seed >= 0:
    torch.manual_seed(args.seed)
52
53 ######################################################################
54
def sequences_to_image(x, gap=1, gap_color=(0, 128, 255)):
    """Tile a batch of image sequences into a single PIL image.

    Args:
        x: tensor of shape (nb_sequences, nb_images_per_sequence,
           3, height, width) with pixel values in 0..255.
           One grid row is produced per sequence, one column per image.
        gap: border width in pixels drawn between and around the tiles.
        gap_color: RGB triple used to fill the border area.

    Returns:
        A PIL.Image in 'RGB' mode.
    """
    from PIL import Image

    nb_sequences = x.size(0)
    nb_images_per_sequences = x.size(1)
    nb_channels = 3

    if x.size(2) != nb_channels:
        print('Can only handle 3 channel tensors.')
        exit(1)

    height = x.size(3)
    width = x.size(4)

    result = torch.ByteTensor(nb_channels,
                              gap + nb_sequences * (height + gap),
                              gap + nb_images_per_sequences * (width + gap))

    # Paint the whole canvas with the gap color, then blit each tile over it.
    # torch.tensor replaces the legacy torch.Tensor(sequence) constructor.
    result.copy_(torch.tensor(gap_color).view(-1, 1, 1).expand_as(result))

    for s in range(0, nb_sequences):
        for i in range(0, nb_images_per_sequences):
            result.narrow(1, gap + s * (height + gap), height) \
                  .narrow(2, gap + i * (width + gap), width) \
                  .copy_(x[s][i])

    # CHW -> HWC for PIL. .contiguous() is required because permute only
    # changes strides, and Image.fromarray needs a dense buffer (the old
    # transpose chain handed PIL a non-contiguous view). The previous
    # .cpu().byte() was redundant: result is a CPU ByteTensor by construction.
    result_numpy = result.permute(1, 2, 0).contiguous().numpy()

    return Image.fromarray(result_numpy, 'RGB')
84
85 ######################################################################
86
# Generate the requested batch of sequences with the flatland C extension,
# then dump them to disk as one tiled PNG with black separators.
sequences = flatland.generate_sequence(args.nb_sequences,
                                       args.nb_images_per_sequences,
                                       args.height, args.width,
                                       args.nb_shapes,
                                       args.randomize_colors,
                                       args.randomize_shape_size)

image = sequences_to_image(sequences, gap = 1, gap_color = (0, 0, 0))
image.save('sequences.png')