import torch
import torchvision
-from torchvision import datasets
+import argparse
+
+from _ext import flatland
+
+######################################################################
+
+parser = argparse.ArgumentParser(
+ description='Dummy test of the flatland sequence generation.',
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+)
+
+parser.add_argument('--seed',
+ type = int, default = 0,
+ help = 'Random seed, < 0 is no seeding')
+
+parser.add_argument('--width',
+ type = int, default = 80,
+ help = 'Image width')
+
+parser.add_argument('--height',
+ type = int, default = 80,
+ help = 'Image height')
+
+parser.add_argument('--nb_shapes',
+ type = int, default = 10,
+ help = 'Image height')
+
+parser.add_argument('--nb_sequences',
+ type = int, default = 1,
+ help = 'How many sequences to generate')
+
+parser.add_argument('--nb_images_per_sequences',
+ type = int, default = 3,
+ help = 'How many images per sequence')
+
+parser.add_argument('--randomize_colors',
+ action='store_true', default=False,
+ help = 'Should the shapes be of different colors')
+
+parser.add_argument('--randomize_shape_size',
+ action='store_true', default=False,
+ help = 'Should the shapes be of different size')
+
+args = parser.parse_args()
+
+if args.seed >= 0:
+ torch.manual_seed(args.seed)
######################################################################
-def sequences_to_image(x):
+def sequences_to_image(x, gap = 1, gap_color = (0, 128, 255)):
from PIL import Image
nb_sequences = x.size(0)
height = x.size(3)
width = x.size(4)
- gap = 1
- gap_color = (0, 128, 255)
result = torch.ByteTensor(nb_channels,
gap + nb_sequences * (height + gap),
gap + nb_images_per_sequences * (width + gap))
- result[0].fill_(gap_color[0])
- result[1].fill_(gap_color[1])
- result[2].fill_(gap_color[2])
+ result.copy_(torch.Tensor(gap_color).view(-1, 1, 1).expand_as(result))
for s in range(0, nb_sequences):
for i in range(0, nb_images_per_sequences):
- result.narrow(1, gap + s * (height + gap), height).narrow(2, gap + i * (width + gap), width).copy_(x[s][i])
+ result.narrow(1, gap + s * (height + gap), height) \
+ .narrow(2, gap + i * (width + gap), width) \
+ .copy_(x[s][i])
result_numpy = result.cpu().byte().transpose(0, 2).transpose(0, 1).numpy()
######################################################################
-from _ext import flatland
-
-x = torch.ByteTensor()
-
-flatland.generate_sequence(10, x)
+x = flatland.generate_sequence(args.nb_sequences,
+ args.nb_images_per_sequences,
+ args.height, args.width,
+ args.nb_shapes,
+ args.randomize_colors,
+ args.randomize_shape_size)
-sequences_to_image(x).save('sequences.png')
+sequences_to_image(x, gap = 1, gap_color = (0, 0, 0)).save('sequences.png')