1 #!/usr/bin/env python-for-pytorch
7 from _ext import flatland
9 ######################################################################
# Command-line options for the flatland sequence-generation smoke test.
parser = argparse.ArgumentParser(
    description = 'Dummy test of the flatland sequence generation.',
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--seed',
                    type = int, default = 0,
                    help = 'Random seed, < 0 is no seeding')

parser.add_argument('--width',
                    type = int, default = 80,
                    help = 'Image width')

parser.add_argument('--height',
                    type = int, default = 80,
                    help = 'Image height')

parser.add_argument('--nb_shapes',
                    type = int, default = 10,
                    help = 'Number of shapes')

parser.add_argument('--nb_sequences',
                    type = int, default = 1,
                    help = 'How many sequences to generate')

parser.add_argument('--nb_images_per_sequences',
                    type = int, default = 3,
                    help = 'How many images per sequence')

parser.add_argument('--randomize_colors',
                    action='store_true', default=False,
                    help = 'Should the shapes be of different colors')

parser.add_argument('--randomize_shape_size',
                    action='store_true', default=False,
                    help = 'Should the shapes be of different size')

args = parser.parse_args()

# A negative --seed means "do not seed" (as its help text states), so only
# call torch.manual_seed for non-negative values; otherwise torch's RNG state
# is left untouched.
if args.seed >= 0:
    torch.manual_seed(args.seed)
53 ######################################################################
def sequences_to_image(x, gap = 1, gap_color = (0, 128, 255)):
    """Tile a batch of RGB image sequences into a single PIL image.

    Sequence ``s`` occupies grid row ``s`` and its ``i``-th image occupies
    grid column ``i``. Every tile is surrounded by a ``gap``-pixel border
    filled with ``gap_color``.

    Args:
        x: tensor indexed as (sequence, image, channel, height, width); the
            channel dimension must have size 3. Values are copied into a
            uint8 canvas, so they are presumably already in [0, 255] --
            TODO confirm against the generator.
        gap: border width in pixels between and around the tiles.
        gap_color: RGB triple used to paint the borders.

    Returns:
        An 'RGB' ``PIL.Image.Image`` whose height is
        ``gap + nb_sequences * (height + gap)`` and whose width is
        ``gap + nb_images_per_sequences * (width + gap)``.

    Raises:
        ValueError: if ``x.size(2)`` is not 3.
    """
    nb_channels = 3
    nb_sequences = x.size(0)
    nb_images_per_sequences = x.size(1)

    if x.size(2) != nb_channels:
        # Raise rather than print so callers decide how to handle bad input.
        raise ValueError('Can only handle 3 channel tensors.')

    height = x.size(3)
    width = x.size(4)

    result = torch.ByteTensor(nb_channels,
                              gap + nb_sequences * (height + gap),
                              gap + nb_images_per_sequences * (width + gap))

    # Paint the whole canvas with the gap colour; tiles overwrite it below.
    result.copy_(torch.Tensor(gap_color).view(-1, 1, 1).expand_as(result))

    for s in range(nb_sequences):
        for i in range(nb_images_per_sequences):
            # Select the (height x width) window of tile (s, i) and fill it.
            result.narrow(1, gap + s * (height + gap), height) \
                  .narrow(2, gap + i * (width + gap), width) \
                  .copy_(x[s][i])

    # The canvas is (channel, height, width); PIL wants (height, width, channel).
    result_numpy = result.permute(1, 2, 0).contiguous().numpy()

    # Imported lazily: PIL is only needed for this final conversion.
    from PIL import Image

    return Image.fromarray(result_numpy, 'RGB')
85 ######################################################################
# Generate the sequences with the flatland C extension, then save them as one
# tiled PNG (one grid row per sequence, black borders one pixel wide).
# NOTE(review): --nb_sequences and --nb_shapes are parsed but not forwarded in
# this call; confirm against the flatland.generate_sequence signature.
x = flatland.generate_sequence(
    False,
    args.nb_images_per_sequences,
    args.height,
    args.width,
    args.randomize_colors,
    args.randomize_shape_size,
)

mosaic = sequences_to_image(x, gap = 1, gap_color = (0, 0, 0))
mosaic.save('sequences.png')