X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=flatland.git;a=blobdiff_plain;f=test.py;h=e058269e06d49ffb4543c88753d99ec7ff0275e9;hp=bf51360fbcf7a014071d3c0c64e746734767601d;hb=90d1c5704c30e7f1d041e32eacbc2893741110e1;hpb=c2a7c7d6dfec8bd1eca29406d160cce5b4a35209 diff --git a/test.py b/test.py index bf51360..e058269 100755 --- a/test.py +++ b/test.py @@ -2,11 +2,57 @@ import torch import torchvision -from torchvision import datasets +import argparse + +from _ext import flatland + +###################################################################### + +parser = argparse.ArgumentParser( + description = 'Dummy test of the flatland sequence generation.', + formatter_class = argparse.ArgumentDefaultsHelpFormatter +) + +parser.add_argument('--seed', + type = int, default = 0, + help = 'Random seed, < 0 is no seeding') + +parser.add_argument('--width', + type = int, default = 80, + help = 'Image width') + +parser.add_argument('--height', + type = int, default = 80, + help = 'Image height') + +parser.add_argument('--nb_shapes', + type = int, default = 10, + help = 'Image height') + +parser.add_argument('--nb_sequences', + type = int, default = 1, + help = 'How many sequences to generate') + +parser.add_argument('--nb_images_per_sequences', + type = int, default = 3, + help = 'How many images per sequence') + +parser.add_argument('--randomize_colors', + action='store_true', default=False, + help = 'Should the shapes be of different colors') + +parser.add_argument('--randomize_shape_size', + action='store_true', default=False, + help = 'Should the shapes be of different size') + +args = parser.parse_args() + +if args.seed >= 0: + torch.manual_seed(args.seed) ###################################################################### -def sequences_to_image(x): +def sequences_to_image(x, gap = 1, gap_color = (0, 128, 255)): from PIL import Image nb_sequences = x.size(0) @@ -19,20 +65,18 @@ def sequences_to_image(x): height = x.size(3) width = x.size(4) - gap = 1 - gap_color = (0, 128, 255) result = torch.ByteTensor(nb_channels, gap + nb_sequences * (height + gap), gap + nb_images_per_sequences * (width + gap)) - result[0].fill_(gap_color[0]) - result[1].fill_(gap_color[1]) - result[2].fill_(gap_color[2]) + result.copy_(torch.Tensor(gap_color).view(-1, 1, 1).expand_as(result)) for s in range(0, nb_sequences): for i in range(0, nb_images_per_sequences): - result.narrow(1, gap + s * (height + gap), height).narrow(2, gap + i * (width + gap), width).copy_(x[s][i]) + result.narrow(1, gap + s * (height + gap), height) \ + .narrow(2, gap + i * (width + gap), width) \ + .copy_(x[s][i]) result_numpy = result.cpu().byte().transpose(0, 2).transpose(0, 1).numpy() @@ -40,10 +84,12 @@ def sequences_to_image(x): ###################################################################### -from _ext import mylib - -x = torch.ByteTensor() - -mylib.generate_sequence(10, x) +x = flatland.generate_sequence(False, + args.nb_sequences, + args.nb_images_per_sequences, + args.height, args.width, + args.nb_shapes, + args.randomize_colors, + args.randomize_shape_size) -sequences_to_image(x).save('sequences.png') +sequences_to_image(x, gap = 1, gap_color = (0, 0, 0)).save('sequences.png')