Update.
[flatland.git] / test.py
diff --git a/test.py b/test.py
index 0a065fa..ac3eb7e 100755 (executable)
--- a/test.py
+++ b/test.py
@@ -2,13 +2,57 @@
 
 import torch
 import torchvision
-from torchvision import datasets
+import argparse
 
 from _ext import flatland
 
 ######################################################################
 
-def sequences_to_image(x, gap=1, gap_color = (0, 128, 255)):
+parser = argparse.ArgumentParser(
+    description='Dummy test of the flatland sequence generation.',
+    formatter_class=argparse.ArgumentDefaultsHelpFormatter
+)
+
+parser.add_argument('--seed',
+                    type = int, default = 0,
+                    help = 'Random seed, < 0 is no seeding')
+
+parser.add_argument('--width',
+                    type = int, default = 80,
+                    help = 'Image width')
+
+parser.add_argument('--height',
+                    type = int, default = 80,
+                    help = 'Image height')
+
+parser.add_argument('--nb_shapes',
+                    type = int, default = 10,
+                    help = 'Image height')
+
+parser.add_argument('--nb_sequences',
+                    type = int, default = 1,
+                    help = 'How many sequences to generate')
+
+parser.add_argument('--nb_images_per_sequences',
+                    type = int, default = 3,
+                    help = 'How many images per sequence')
+
+parser.add_argument('--randomize_colors',
+                    action='store_true', default=False,
+                    help = 'Should the shapes be of different colors')
+
+parser.add_argument('--randomize_shape_size',
+                    action='store_true', default=False,
+                    help = 'Should the shapes be of different size')
+
+args = parser.parse_args()
+
+if args.seed >= 0:
+    torch.manual_seed(args.seed)
+
+######################################################################
+
+def sequences_to_image(x, gap = 1, gap_color = (0, 128, 255)):
     from PIL import Image
 
     nb_sequences = x.size(0)
@@ -26,9 +70,7 @@ def sequences_to_image(x, gap=1, gap_color = (0, 128, 255)):
                               gap + nb_sequences * (height + gap),
                               gap + nb_images_per_sequences * (width + gap))
 
-    result[0].fill_(gap_color[0])
-    result[1].fill_(gap_color[1])
-    result[2].fill_(gap_color[2])
+    result.copy_(torch.Tensor(gap_color).view(-1, 1, 1).expand_as(result))
 
     for s in range(0, nb_sequences):
         for i in range(0, nb_images_per_sequences):
@@ -42,6 +84,11 @@ def sequences_to_image(x, gap=1, gap_color = (0, 128, 255)):
 
 ######################################################################
 
-x = flatland.generate_sequence(10, 6, 80, 80, True, True)
+x = flatland.generate_sequence(args.nb_sequences,
+                               args.nb_images_per_sequences,
+                               args.height, args.width,
+                               args.nb_shapes,
+                               args.randomize_colors,
+                               args.randomize_shape_size)
 
-sequences_to_image(x, gap = 2, gap_color = (0, 0, 0)).save('sequences.png')
+sequences_to_image(x, gap = 1, gap_color = (0, 0, 0)).save('sequences.png')